From 7c0c1b922e76ce033f645be37f643ff0f3086b7f Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Wed, 8 Apr 2026 15:51:05 -0700 Subject: [PATCH 01/26] Initial commit --- cmake/onnxruntime_unittests.cmake | 3 +- .../webgpu_matmul_nbits_decode.cc | 237 ++++++++++++++++++ 2 files changed, 239 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 280ec829c268d..55d930d74c40a 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1355,7 +1355,8 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ${BENCHMARK_DIR}/activation.cc ${BENCHMARK_DIR}/quantize.cc ${BENCHMARK_DIR}/reduceminmax.cc - ${BENCHMARK_DIR}/layer_normalization.cc) + ${BENCHMARK_DIR}/layer_normalization.cc + ${BENCHMARK_DIR}/webgpu_matmul_nbits_decode.cc) target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) target_compile_definitions(onnxruntime_benchmark PRIVATE ${mlas_private_compile_definitions}) diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc new file mode 100644 index 0000000000000..bc1e52e3ba90a --- /dev/null +++ b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc @@ -0,0 +1,237 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include +#include +#include + +#include +#include +#include + +extern OrtEnv* env; +extern const OrtApi* g_ort; + +namespace { + +struct DecodeBenchConfig { + int64_t n; + int64_t k; + int64_t bits; + int64_t block_size; + int64_t accuracy_level; +}; + +template +void AddTensorInitializer(ONNX_NAMESPACE::GraphProto& graph, + const std::string& name, + int32_t data_type, + const std::vector& dims, + const std::vector& values) { + auto* initializer = graph.add_initializer(); + initializer->set_name(name); + initializer->set_data_type(data_type); + for (int64_t dim : dims) { + initializer->add_dims(dim); + } + + initializer->set_raw_data(values.data(), values.size() * sizeof(T)); +} + +std::vector GetDecodeBenchConfigs() { + // Each entry is {N, K, bits, block_size, accuracy_level} for a decode-style M=1 run. + return { + {5120, 3072, 4, 32, 4}, + {8192, 3072, 4, 32, 4}, + {3072, 8192, 4, 32, 4}, + {200064, 3072, 4, 32, 4}, + }; +} + +void AddMatMulNBitsNode(ONNX_NAMESPACE::GraphProto& graph, + const std::string& node_name, + const std::string& input_name, + const std::string& weight_name, + const std::string& scale_name, + const std::string& bias_name, + const std::string& output_name, + int64_t k, + int64_t n, + int64_t bits, + int64_t block_size, + int64_t accuracy_level) { + auto* node = graph.add_node(); + node->set_name(node_name); + node->set_op_type("MatMulNBits"); + node->set_domain("com.microsoft"); + node->add_input(input_name); + node->add_input(weight_name); + node->add_input(scale_name); + node->add_input(""); + node->add_input(""); + node->add_input(bias_name); + node->add_output(output_name); + + auto* attr_k = node->add_attribute(); + attr_k->set_name("K"); + attr_k->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_k->set_i(k); + + auto* attr_n = node->add_attribute(); + attr_n->set_name("N"); + attr_n->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_n->set_i(n); + + auto* attr_bits = node->add_attribute(); + attr_bits->set_name("bits"); + attr_bits->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_bits->set_i(bits); + + auto* attr_block = node->add_attribute(); + attr_block->set_name("block_size"); + attr_block->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_block->set_i(block_size); + + auto* attr_accuracy = node->add_attribute(); + attr_accuracy->set_name("accuracy_level"); + attr_accuracy->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_accuracy->set_i(accuracy_level); +} + +std::vector SerializeMatMulNBitsModel(const DecodeBenchConfig& config) { + const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; + const int64_t blob_size = (config.block_size * config.bits) / 8; + + ONNX_NAMESPACE::ModelProto model; + model.set_ir_version(10); + + auto* onnx_opset = model.add_opset_import(); + onnx_opset->set_domain(""); + onnx_opset->set_version(21); + auto* ms_opset = model.add_opset_import(); + ms_opset->set_domain("com.microsoft"); + ms_opset->set_version(1); + + auto* graph = model.mutable_graph(); + graph->set_name("WebGpuMatMulNBitsDecode"); + + auto* input = graph->add_input(); + input->set_name("A"); + input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); + + auto* output = graph->add_output(); + output->set_name("Y"); + output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.n); + + std::vector packed_b(static_cast(config.n * k_blocks * blob_size), uint8_t{0x11}); + std::vector scales(static_cast(config.n * k_blocks), Ort::Float16_t(0.03125f)); + std::vector bias(static_cast(config.n), Ort::Float16_t(0.125f)); + + AddTensorInitializer(*graph, "B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, + {config.n, k_blocks, blob_size}, packed_b); + AddTensorInitializer(*graph, "scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + {config.n, k_blocks}, scales); + AddTensorInitializer(*graph, "bias", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + {config.n}, bias); + + AddMatMulNBitsNode(*graph, + "MatMulNBitsDecode", + "A", + "B", + "scales", + "bias", + "Y", + config.k, + config.n, + config.bits, + config.block_size, + config.accuracy_level); + + const auto serialized = model.SerializeAsString(); + return std::vector(serialized.begin(), serialized.end()); +} + +static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { + const DecodeBenchConfig config{ + state.range(0), + state.range(1), + state.range(2), + state.range(3), + state.range(4), + }; + + if (config.k % config.block_size != 0) { + state.SkipWithError("K must be divisible by block_size for this benchmark skeleton."); + return; + } + + std::vector model_data = SerializeMatMulNBitsModel(config); + + Ort::SessionOptions session_options; + session_options.DisableMemPattern(); + session_options.AppendExecutionProvider("WebGPU", std::unordered_map{}); + + OrtSession* raw_session = nullptr; + OrtStatus* status = g_ort->CreateSessionFromArray(env, model_data.data(), model_data.size(), session_options, &raw_session); + if (status != nullptr) { + state.SkipWithError(g_ort->GetErrorMessage(status)); + g_ort->ReleaseStatus(status); + return; + } + + Ort::Session session{raw_session}; + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + std::vector input_shape{1, config.k}; + std::vector activation(static_cast(config.k)); + + std::mt19937 rng(123); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for (auto& value : activation) { + value = Ort::Float16_t(dist(rng)); + } + + const char* input_names[] = {"A"}; + const char* output_names[] = {"Y"}; + + auto input_tensor = Ort::Value::CreateTensor(memory_info, + activation.data(), + activation.size(), + input_shape.data(), + input_shape.size()); + + for (int i = 0; i < 10; ++i) { + auto warmup_outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); + benchmark::DoNotOptimize(warmup_outputs); + } + + for (auto _ : state) { + auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); + benchmark::DoNotOptimize(outputs); + } + + const double total_flops = 2.0 * static_cast(config.n) * static_cast(config.k); + + state.SetLabel("fp16_decode_bias"); + state.counters["TFLOPS"] = benchmark::Counter( + total_flops, + benchmark::Counter::kIsIterationInvariantRate); +} + +void ApplyWebGpuMatMulNBitsDecodeArgs(benchmark::internal::Benchmark* benchmark) { + for (const auto& config : GetDecodeBenchConfigs()) { + benchmark->Args({config.n, config.k, config.bits, config.block_size, config.accuracy_level}); + } +} + +BENCHMARK(BM_WebGpuMatMulNBitsDecode) + ->Apply(ApplyWebGpuMatMulNBitsDecodeArgs) + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +} // namespace \ No newline at end of file From a0550b689955d3b64cf5394d07cd8405c9597dbd Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Wed, 8 Apr 2026 17:20:26 -0700 Subject: [PATCH 02/26] More changes --- .../webgpu_matmul_nbits_decode.cc | 212 +++++++++++++++++- 1 file changed, 209 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc index bc1e52e3ba90a..fd2e46c6bbbf4 100644 --- a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc +++ b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc @@ -3,7 +3,10 @@ #include +#include #include +#include +#include #include #include @@ -11,6 +14,9 @@ #include #include +#include +#include + extern OrtEnv* env; extern const OrtApi* g_ort; @@ -24,6 +30,205 @@ struct DecodeBenchConfig { int64_t accuracy_level; }; +struct AdapterSelectionConfig { + // adapter_type: Dawn adapter type to select, e.g. integrated or discrete GPU. + // adapter_index: zero-based index among only adapters of adapter_type. + // context_id: ORT WebGPU custom context ID used to bind the externally created instance/device. + // backend_type: Dawn backend to enumerate adapters from, e.g. D3D12 or Vulkan. + // print_adapter_list: whether to print all discovered adapters before selecting one. + WGPUAdapterType adapter_type; + int adapter_index; + int context_id; + WGPUBackendType backend_type; + bool print_adapter_list; +}; + +struct AdapterCandidate { + dawn::native::Adapter adapter; + int global_index; + WGPUAdapterType adapter_type; + int type_index; + uint32_t vendor_id; + uint32_t device_id; + std::string vendor; + std::string architecture; + std::string device; + std::string description; +}; + +struct SelectedWebGpuContext { + std::unique_ptr dawn_instance; + WGPUInstance instance{nullptr}; + WGPUDevice device{nullptr}; + std::unordered_map provider_options; + std::string selected_adapter_summary; +}; + +std::string ToString(WGPUStringView value) { + return value.data == nullptr ? std::string{} : std::string(value.data, value.length); +} + +const char* AdapterTypeToString(WGPUAdapterType adapter_type) { + switch (adapter_type) { + case WGPUAdapterType_DiscreteGPU: + return "discrete"; + case WGPUAdapterType_IntegratedGPU: + return "integrated"; + case WGPUAdapterType_CPU: + return "cpu"; + default: + return "unknown"; + } +} + +bool IsGpuAdapterType(WGPUAdapterType adapter_type) { + return adapter_type == WGPUAdapterType_DiscreteGPU || + adapter_type == WGPUAdapterType_IntegratedGPU; +} + +std::string FormatAdapterSummary(const AdapterCandidate& adapter) { + std::ostringstream stream; + stream << "adapter[" << adapter.global_index << "]" + << " type=" << AdapterTypeToString(adapter.adapter_type) + << " type_index=" << adapter.type_index + << " vendor=" << adapter.vendor + << " architecture=" << adapter.architecture + << " gpu_name=" << adapter.device + << " description=" << adapter.description + << " vendor_id=" << adapter.vendor_id + << " device_id=" << adapter.device_id; + return stream.str(); +} + +AdapterSelectionConfig GetAdapterSelectionConfig() { + // Pick the second discrete adapter by default so this benchmark can target the + // "other" dGPU on a machine with two discrete GPUs and one integrated GPU. + return { + WGPUAdapterType_DiscreteGPU, // adapter_type + 1, // adapter_index + 1, // context_id + WGPUBackendType_D3D12, // backend_type + true, // print_adapter_list + }; +} + +SelectedWebGpuContext CreateSelectedWebGpuContext() { + const AdapterSelectionConfig config = GetAdapterSelectionConfig(); + + SelectedWebGpuContext selected_context; + selected_context.dawn_instance = std::make_unique(); + + WGPURequestAdapterOptions adapter_options = WGPU_REQUEST_ADAPTER_OPTIONS_INIT; + adapter_options.backendType = config.backend_type; + adapter_options.powerPreference = WGPUPowerPreference_Undefined; + + std::vector adapters = selected_context.dawn_instance->EnumerateAdapters(&adapter_options); + if (adapters.empty()) { + throw std::runtime_error("No Dawn adapters were found for the configured backend."); + } + + std::vector candidates; + candidates.reserve(adapters.size()); + int discrete_index = 0; + int integrated_index = 0; + int cpu_index = 0; + int unknown_index = 0; + for (size_t i = 0; i < adapters.size(); ++i) { + WGPUAdapterInfo info = WGPU_ADAPTER_INFO_INIT; + if (wgpuAdapterGetInfo(adapters[i].Get(), &info) != WGPUStatus_Success) { + continue; + } + + const WGPUAdapterType adapter_type = info.adapterType; + int current_type_index = 0; + switch (adapter_type) { + case WGPUAdapterType_DiscreteGPU: + current_type_index = discrete_index++; + break; + case WGPUAdapterType_IntegratedGPU: + current_type_index = integrated_index++; + break; + case WGPUAdapterType_CPU: + current_type_index = cpu_index++; + break; + default: + current_type_index = unknown_index++; + break; + } + candidates.push_back(AdapterCandidate{ + adapters[i], + static_cast(i), + adapter_type, + current_type_index, + info.vendorID, + info.deviceID, + ToString(info.vendor), + ToString(info.architecture), + ToString(info.device), + ToString(info.description), + }); + + wgpuAdapterInfoFreeMembers(info); + } + + if (config.print_adapter_list) { + std::cout << "Available Dawn GPU adapters for WebGPU benchmark:" << std::endl; + bool printed_gpu = false; + for (const auto& candidate : candidates) { + if (!IsGpuAdapterType(candidate.adapter_type)) { + continue; + } + + printed_gpu = true; + std::cout << " " << FormatAdapterSummary(candidate) << std::endl; + } + + if (!printed_gpu) { + std::cout << " No integrated or discrete GPU adapters were found." << std::endl; + } + } + + const AdapterCandidate* selected_adapter = nullptr; + for (const auto& candidate : candidates) { + if (candidate.adapter_type == config.adapter_type && + candidate.type_index == config.adapter_index) { + selected_adapter = &candidate; + break; + } + } + + if (selected_adapter == nullptr) { + std::ostringstream stream; + stream << "Failed to find " << AdapterTypeToString(config.adapter_type) + << " adapter index " << config.adapter_index + << ". Update GetAdapterSelectionConfig() to match the available adapters listed above."; + throw std::runtime_error(stream.str()); + } + + selected_context.instance = selected_context.dawn_instance->Get(); + selected_context.device = selected_adapter->adapter.CreateDevice(); + if (selected_context.device == nullptr) { + throw std::runtime_error("Failed to create a WGPUDevice for the selected adapter."); + } + + selected_context.selected_adapter_summary = FormatAdapterSummary(*selected_adapter); + std::cout << "Selected Dawn adapter for WebGPU benchmark: " + << selected_context.selected_adapter_summary << std::endl; + + selected_context.provider_options["deviceId"] = std::to_string(config.context_id); + selected_context.provider_options["webgpuInstance"] = std::to_string(reinterpret_cast(selected_context.instance)); + selected_context.provider_options["webgpuDevice"] = std::to_string(reinterpret_cast(selected_context.device)); + selected_context.provider_options["preserveDevice"] = "1"; + selected_context.provider_options["dawnProcTable"] = std::to_string(reinterpret_cast(&dawn::native::GetProcs())); + + return selected_context; +} + +const SelectedWebGpuContext& GetSelectedWebGpuContext() { + static const SelectedWebGpuContext selected_context = CreateSelectedWebGpuContext(); + return selected_context; +} + template void AddTensorInitializer(ONNX_NAMESPACE::GraphProto& graph, const std::string& name, @@ -172,10 +377,11 @@ static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { } std::vector model_data = SerializeMatMulNBitsModel(config); + const SelectedWebGpuContext& selected_context = GetSelectedWebGpuContext(); Ort::SessionOptions session_options; session_options.DisableMemPattern(); - session_options.AppendExecutionProvider("WebGPU", std::unordered_map{}); + session_options.AppendExecutionProvider("WebGPU", selected_context.provider_options); OrtSession* raw_session = nullptr; OrtStatus* status = g_ort->CreateSessionFromArray(env, model_data.data(), model_data.size(), session_options, &raw_session); @@ -217,7 +423,7 @@ static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { const double total_flops = 2.0 * static_cast(config.n) * static_cast(config.k); - state.SetLabel("fp16_decode_bias"); + state.SetLabel("fp16_decode_bias_custom_adapter"); state.counters["TFLOPS"] = benchmark::Counter( total_flops, benchmark::Counter::kIsIterationInvariantRate); @@ -234,4 +440,4 @@ BENCHMARK(BM_WebGpuMatMulNBitsDecode) ->UseRealTime() ->Unit(benchmark::TimeUnit::kMicrosecond); -} // namespace \ No newline at end of file +} // namespace From ee09d8e4f1d70b6900ce66ee8429bb77db5c7bb5 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 13 Apr 2026 11:31:16 -0700 Subject: [PATCH 03/26] Stage --- .../webgpu/quantization/matmul_nbits.cc | 520 ++++++++++++++++-- .../core/providers/webgpu/compute_context.h | 6 + .../core/providers/webgpu/webgpu_context.cc | 10 + .../core/providers/webgpu/webgpu_context.h | 1 + onnxruntime/core/session/ort_version_check.h | 18 +- onnxruntime/test/onnx/microbenchmark/main.cc | 2 +- .../webgpu_matmul_nbits_decode.cc | 463 +++++++++++++--- 7 files changed, 906 insertions(+), 114 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index e0f87b4b6a6dd..97bf624d1e91d 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -3,12 +3,20 @@ #include #include +#include +#include +#include +#include +#include #include "contrib_ops/webgpu/quantization/matmul_nbits.h" #include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" #include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" #include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/common/inlined_containers.h" +#include "core/common/logging/macros.h" +#include "core/platform/env.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -20,6 +28,427 @@ namespace webgpu { namespace { constexpr unsigned int kMinMForTileOptimization = 4; +constexpr uint32_t kMatMulNBitsMinNForAutoTuning = 65536; +constexpr const char* kMatMulNBitsAutoTunerEnvVar = "ORT_WEBGPU_MATMUL_NBITS_ENABLE_AUTO_TUNER"; +constexpr double kMatMulNBitsTuneMinImprovementRatio = 0.10; + +struct MatMulNBitsTuneParams { + uint32_t tile_size_k_vec; + uint32_t workgroup_size; + uint32_t tile_size; +}; + +std::mutex g_matmul_nbits_tune_mutex; +InlinedHashMap g_matmul_nbits_tuned_params; + +constexpr std::array kMatMulNBitsTuneCandidates{{ + {8u, 64u, 8u}, + {8u, 128u, 8u}, + {8u, 64u, 16u}, + {8u, 128u, 16u}, + {16u, 64u, 8u}, + {16u, 128u, 8u}, + {16u, 64u, 16u}, + {16u, 128u, 16u}, + {32u, 64u, 8u}, + {32u, 128u, 8u}, + {32u, 64u, 16u}, + {32u, 128u, 16u}, +}}; + +constexpr int kMatMulNBitsTuneWarmupRuns = 1; +constexpr int kMatMulNBitsTuneMeasuredRuns = 2; + +std::string_view ToStdStringView(wgpu::StringView value) { + return std::string_view{value.data ? value.data : "", value.length}; +} + +std::string MakeMatMulNBitsTuneKey(const onnxruntime::webgpu::ComputeContext& context, + uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + uint32_t nbits, + bool single_scale_weights, + bool is_fp16) { + return MakeStringWithClassicLocale(ToStdStringView(context.AdapterInfo().vendor), + "|", + ToStdStringView(context.AdapterInfo().architecture), + "|", + ToStdStringView(context.AdapterInfo().device), + "|M=", + M, + "|N=", + N, + "|K=", + K, + "|block=", + block_size, + "|bits=", + nbits, + "|single_scale=", + single_scale_weights ? 1 : 0, + "|fp16=", + is_fp16 ? 1 : 0); +} + +bool IsValidTuneCandidate(const onnxruntime::webgpu::ComputeContext& context, + const MatMulNBitsTuneParams& candidate) { + if (candidate.workgroup_size % candidate.tile_size_k_vec != 0 || + candidate.workgroup_size > context.DeviceLimits().maxComputeInvocationsPerWorkgroup) { + return false; + } + + const uint32_t sub_tile_count = candidate.workgroup_size / candidate.tile_size_k_vec; + return sub_tile_count > 0 && candidate.tile_size % sub_tile_count == 0; +} + +MatMulNBitsTuneParams GetDefaultMatMulNBitsTuneParams(const onnxruntime::webgpu::ComputeContext& context) { + return MatMulNBitsTuneParams{ + (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u, + 128u, + 8u, + }; +} + +bool ShouldTuneDefaultMatMulNBitsProgram(const onnxruntime::webgpu::ComputeContext& context, + uint32_t batch_count, + uint32_t M, + uint32_t N, + bool has_zero_points, + bool has_bias, + bool has_weight_idx, + bool has_weight_idx_indirect) { + std::string auto_tuner_env = Env::Default().GetEnvironmentVar(kMatMulNBitsAutoTunerEnvVar); + if (auto_tuner_env.empty()) { + return false; + } + + std::transform(auto_tuner_env.begin(), auto_tuner_env.end(), auto_tuner_env.begin(), + [](unsigned char value) { return static_cast(std::tolower(value)); }); + if (auto_tuner_env == "0" || auto_tuner_env == "false" || auto_tuner_env == "off") { + return false; + } + + return !context.IsGraphCaptureEnabled() && + batch_count == 1 && + M == 1 && + N >= kMatMulNBitsMinNForAutoTuning && + !has_zero_points && + !has_bias && + !has_weight_idx && + !has_weight_idx_indirect; +} + +Status RunDefaultMatMulNBitsProgram(const Tensor* a, + const Tensor* b, + const Tensor* scales, + const Tensor* zero_points, + const Tensor* bias, + uint32_t batch_count, + uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + uint32_t n_blocks_per_col, + uint32_t zero_blocks_per_col, + uint32_t blob_size, + uint32_t nbits, + bool has_zero_points, + bool has_bias, + bool has_weight_idx, + bool has_weight_idx_indirect, + bool single_scale_weights, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + uint32_t weight_index, + const Tensor* weight_index_indirect, + const MatMulNBitsTuneParams& params, + bool wait_for_completion = false) { + constexpr uint32_t kU32Components = 4; + const uint32_t components_a = GetMaxComponents(K); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + const uint32_t components_b_with_u32 = components_b * kU32Components; + const uint32_t num_N_tile = CeilDiv(N, params.tile_size); + const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; + + MatMulNBitsProgram program{params.tile_size, + nbits, + has_zero_points, + has_bias, + has_weight_idx, + has_weight_idx_indirect, + single_scale_weights, + params.tile_size_k_vec}; + program.SetWorkgroupSize(params.workgroup_size); + program.SetDispatchGroupSize(num_N_tile, M, batch_count); + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) + .AddUniformVariables({{M}, + {N}, + {K}, + {K / components_a}, + {K_of_b}, + {block_size}, + {n_blocks_per_col}, + {zero_blocks_per_col}, + {num_N_tile}, + {batch_count}, + {weight_index}}) + .CacheHint(nbits, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect, params.tile_size_k_vec); + if (has_zero_points) { + program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + } + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::None}); + } + if (has_weight_idx_indirect) { + program.AddInput({weight_index_indirect, ProgramTensorMetadataDependency::None}); + } + + ORT_RETURN_IF_ERROR(context.RunProgram(program)); + if (wait_for_completion) { + return context.FlushAndWait(); + } + return Status::OK(); +} + +MatMulNBitsTuneParams GetTunedMatMulNBitsParams(const Tensor* a, + const Tensor* b, + const Tensor* scales, + const Tensor* zero_points, + const Tensor* bias, + uint32_t batch_count, + uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + uint32_t n_blocks_per_col, + uint32_t zero_blocks_per_col, + uint32_t blob_size, + uint32_t nbits, + bool has_zero_points, + bool has_bias, + bool has_weight_idx, + bool has_weight_idx_indirect, + bool single_scale_weights, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + uint32_t weight_index, + const Tensor* weight_index_indirect) { + const bool is_fp16 = y->DataType() == DataTypeImpl::GetType(); + const std::string tune_key = MakeMatMulNBitsTuneKey(context, M, N, K, block_size, nbits, single_scale_weights, is_fp16); + + { + std::lock_guard lock(g_matmul_nbits_tune_mutex); + const auto it = g_matmul_nbits_tuned_params.find(tune_key); + if (it != g_matmul_nbits_tuned_params.end()) { + return it->second; + } + } + + const MatMulNBitsTuneParams default_params = GetDefaultMatMulNBitsTuneParams(context); + MatMulNBitsTuneParams best_params = default_params; + double default_seconds = 0.0; + double best_seconds = 0.0; + + bool default_failed = false; + for (int i = 0; i < kMatMulNBitsTuneWarmupRuns; ++i) { + const Status status = RunDefaultMatMulNBitsProgram(a, + b, + scales, + zero_points, + bias, + batch_count, + M, + N, + K, + block_size, + n_blocks_per_col, + zero_blocks_per_col, + blob_size, + nbits, + has_zero_points, + has_bias, + has_weight_idx, + has_weight_idx_indirect, + single_scale_weights, + context, + y, + weight_index, + weight_index_indirect, + default_params, + true); + if (!status.IsOK()) { + default_failed = true; + break; + } + } + + if (!default_failed) { + for (int i = 0; i < kMatMulNBitsTuneMeasuredRuns; ++i) { + const auto start = std::chrono::steady_clock::now(); + const Status status = RunDefaultMatMulNBitsProgram(a, + b, + scales, + zero_points, + bias, + batch_count, + M, + N, + K, + block_size, + n_blocks_per_col, + zero_blocks_per_col, + blob_size, + nbits, + has_zero_points, + has_bias, + has_weight_idx, + has_weight_idx_indirect, + single_scale_weights, + context, + y, + weight_index, + weight_index_indirect, + default_params, + true); + const auto end = std::chrono::steady_clock::now(); + if (!status.IsOK()) { + default_failed = true; + break; + } + default_seconds += std::chrono::duration(end - start).count(); + } + } + + if (default_failed) { + LOGS(context.Logger(), WARNING) << "MatMulNBits tuner kept default params for " << tune_key + << " because the baseline measurement failed" + << ": tile_size_k_vec=" << default_params.tile_size_k_vec + << ", workgroup_size=" << default_params.workgroup_size + << ", tile_size=" << default_params.tile_size; + + std::lock_guard lock(g_matmul_nbits_tune_mutex); + g_matmul_nbits_tuned_params.insert_or_assign(tune_key, default_params); + return default_params; + } + + best_seconds = default_seconds; + + for (const auto& candidate : kMatMulNBitsTuneCandidates) { + if (!IsValidTuneCandidate(context, candidate)) { + continue; + } + + if (candidate.tile_size_k_vec == default_params.tile_size_k_vec && + candidate.workgroup_size == default_params.workgroup_size && + candidate.tile_size == default_params.tile_size) { + continue; + } + + bool candidate_failed = false; + for (int i = 0; i < kMatMulNBitsTuneWarmupRuns; ++i) { + const Status status = RunDefaultMatMulNBitsProgram(a, + b, + scales, + zero_points, + bias, + batch_count, + M, + N, + K, + block_size, + n_blocks_per_col, + zero_blocks_per_col, + blob_size, + nbits, + has_zero_points, + has_bias, + has_weight_idx, + has_weight_idx_indirect, + single_scale_weights, + context, + y, + weight_index, + weight_index_indirect, + candidate, + true); + if (!status.IsOK()) { + candidate_failed = true; + break; + } + } + if (candidate_failed) { + continue; + } + + double candidate_seconds = 0.0; + for (int i = 0; i < kMatMulNBitsTuneMeasuredRuns; ++i) { + const auto start = std::chrono::steady_clock::now(); + const Status status = RunDefaultMatMulNBitsProgram(a, + b, + scales, + zero_points, + bias, + batch_count, + M, + N, + K, + block_size, + n_blocks_per_col, + zero_blocks_per_col, + blob_size, + nbits, + has_zero_points, + has_bias, + has_weight_idx, + has_weight_idx_indirect, + single_scale_weights, + context, + y, + weight_index, + weight_index_indirect, + candidate, + true); + const auto end = std::chrono::steady_clock::now(); + if (!status.IsOK()) { + candidate_failed = true; + break; + } + candidate_seconds += std::chrono::duration(end - start).count(); + } + + if (!candidate_failed && candidate_seconds < best_seconds) { + best_seconds = candidate_seconds; + best_params = candidate; + } + } + + const double improvement_ratio = default_seconds > 0.0 + ? (default_seconds - best_seconds) / default_seconds + : 0.0; + if (improvement_ratio <= kMatMulNBitsTuneMinImprovementRatio) { + best_params = default_params; + best_seconds = default_seconds; + } + + LOGS(context.Logger(), WARNING) << "MatMulNBits tuner selected params for " << tune_key + << ": tile_size_k_vec=" << best_params.tile_size_k_vec + << ", workgroup_size=" << best_params.workgroup_size + << ", tile_size=" << best_params.tile_size + << ", default_measured_seconds=" << default_seconds + << ", selected_measured_seconds=" << best_seconds + << ", improvement_ratio=" << improvement_ratio; + + std::lock_guard lock(g_matmul_nbits_tune_mutex); + g_matmul_nbits_tuned_params.insert_or_assign(tune_key, best_params); + return best_params; +} } // namespace ONNX_OPERATOR_KERNEL_EX( @@ -301,46 +730,59 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, return context.RunProgram(program); } - // Use tile_size_k_vec=32 by default for better K-dimension parallelism. - // Intel devices use 16 as they have different subgroup/cache characteristics. - const uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; - - constexpr uint32_t workgroup_size = 128; - constexpr uint32_t tile_size = 8; - constexpr uint32_t kU32Components = 4; - uint32_t components_b_with_u32 = components_b * kU32Components; - uint32_t num_N_tile = (N + tile_size - 1) / tile_size; - uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; - MatMulNBitsProgram program{tile_size, static_cast(nbits), has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights, tile_size_k_vec}; - program.SetWorkgroupSize(workgroup_size); - program.SetDispatchGroupSize((N + tile_size - 1) / tile_size, M, batch_count); - program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, - {scales, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) - .AddUniformVariables({{M}, - {N}, - {K}, - {K / components_a}, - {K_of_b}, - {block_size}, - {n_blocks_per_col}, - {zero_blocks_per_col}, - {num_N_tile}, - {batch_count}, - {weight_index}}) - .CacheHint(nbits, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect, tile_size_k_vec); - if (has_zero_points) { - program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + MatMulNBitsTuneParams params{}; + if (ShouldTuneDefaultMatMulNBitsProgram(context, batch_count, M, N, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect)) { + params = GetTunedMatMulNBitsParams(a, + b, + scales, + zero_points, + bias, + batch_count, + M, + N, + K, + block_size, + n_blocks_per_col, + zero_blocks_per_col, + blob_size, + static_cast(nbits), + has_zero_points, + has_bias, + has_weight_idx, + has_weight_idx_indirect, + single_scale_weights, + context, + y, + weight_index, + weight_index_indirect); + } else { + params = GetDefaultMatMulNBitsTuneParams(context); } - if (has_bias) { - program.AddInput({bias, ProgramTensorMetadataDependency::None}); - } - if (has_weight_idx_indirect) { - program.AddInput({weight_index_indirect, ProgramTensorMetadataDependency::None}); - } - return context.RunProgram(program); + + return RunDefaultMatMulNBitsProgram(a, + b, + scales, + zero_points, + bias, + batch_count, + M, + N, + K, + block_size, + n_blocks_per_col, + zero_blocks_per_col, + blob_size, + static_cast(nbits), + has_zero_points, + has_bias, + has_weight_idx, + has_weight_idx_indirect, + single_scale_weights, + context, + y, + weight_index, + weight_index_indirect, + params); } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 38848e98509ba..a42b8be1dcbcd 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -38,6 +38,7 @@ class ComputeContextBase { // This ensures no access to BufferManager from other classes, avoiding // potential misuse. friend class WebGpuContext; + friend class ComputeContextBase; private: static const webgpu::BufferManager& Get(const ComputeContextBase& context); @@ -121,6 +122,11 @@ class ComputeContextBase { return webgpu_context_.Run(*this, program); } + inline Status FlushAndWait() { + webgpu_context_.Flush(BufferManagerAccessor::Get(*this)); + return webgpu_context_.WaitForQueueIdle(); + } + protected: WebGpuContext& webgpu_context_; const WebGpuExecutionProvider& ep_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index ada9a2e8ab692..58e71de1fa211 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -184,6 +184,16 @@ Status WebGpuContext::Wait(wgpu::Future f) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); } +Status WebGpuContext::WaitForQueueIdle() { + return Wait(device_queue_.OnSubmittedWorkDone( + wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + ORT_ENFORCE(status == wgpu::QueueWorkDoneStatus::Success, + "Failed to wait for submitted WebGPU work: ", + std::string_view{message}); + })); +} + Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) { const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 021c7f383a6d7..86c40c3b93750 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -165,6 +165,7 @@ class WebGpuContextFactory { class WebGpuContext final { public: Status Wait(wgpu::Future f); + Status WaitForQueueIdle(); const wgpu::Device& Device() const { return device_; } diff --git a/onnxruntime/core/session/ort_version_check.h b/onnxruntime/core/session/ort_version_check.h index 82fd757e3ce9f..dea0d4366fbe2 100644 --- a/onnxruntime/core/session/ort_version_check.h +++ b/onnxruntime/core/session/ort_version_check.h @@ -10,21 +10,27 @@ namespace onnxruntime::version_check { +#if defined(__cpp_consteval) && __cpp_consteval >= 201811L +#define ORT_VERSION_CHECK_CONSTEVAL consteval +#else +#define ORT_VERSION_CHECK_CONSTEVAL constexpr +#endif + // A simple consteval-friendly result type for ParseUint. // std::optional triggers an internal compiler error in MSVC 14.44 when used with consteval. struct ParseUintResult { uint32_t value; bool has_value; - consteval bool operator==(uint32_t other) const { return has_value && value == other; } - consteval bool operator!=(uint32_t other) const { return !(*this == other); } + ORT_VERSION_CHECK_CONSTEVAL bool operator==(uint32_t other) const { return has_value && value == other; } + ORT_VERSION_CHECK_CONSTEVAL bool operator!=(uint32_t other) const { return !(*this == other); } }; -inline consteval ParseUintResult ParseUintNone() { return {0, false}; } +inline ORT_VERSION_CHECK_CONSTEVAL ParseUintResult ParseUintNone() { return {0, false}; } // Parse a non-negative integer from a string_view without leading zeros. // Returns a result with has_value == false on failure (empty, leading zero, non-digit, or overflow). -consteval ParseUintResult ParseUint(std::string_view str) { +ORT_VERSION_CHECK_CONSTEVAL ParseUintResult ParseUint(std::string_view str) { if (str.empty()) return ParseUintNone(); // Leading zeros are not allowed (except "0" itself). if (str.size() > 1 && str[0] == '0') return ParseUintNone(); @@ -42,7 +48,7 @@ consteval ParseUintResult ParseUint(std::string_view str) { // - Major version is 1 // - Y and Z are non-negative integers without leading zeros // - Y (minor version) must equal expected_api_version (defaults to ORT_API_VERSION) -consteval bool IsOrtVersionValid(std::string_view version, uint32_t expected_api_version = ORT_API_VERSION) { +ORT_VERSION_CHECK_CONSTEVAL bool IsOrtVersionValid(std::string_view version, uint32_t expected_api_version = ORT_API_VERSION) { size_t first_dot = version.find('.'); if (first_dot == std::string_view::npos) return false; size_t second_dot = version.find('.', first_dot + 1); @@ -65,4 +71,6 @@ consteval bool IsOrtVersionValid(std::string_view version, uint32_t expected_api return true; } +#undef ORT_VERSION_CHECK_CONSTEVAL + } // namespace onnxruntime::version_check diff --git a/onnxruntime/test/onnx/microbenchmark/main.cc b/onnxruntime/test/onnx/microbenchmark/main.cc index b356dda740a31..a2cb6aaff281a 100644 --- a/onnxruntime/test/onnx/microbenchmark/main.cc +++ b/onnxruntime/test/onnx/microbenchmark/main.cc @@ -79,7 +79,7 @@ int main(int argc, char** argv) { ::benchmark::Initialize(&argc, argv); if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return -1; - ORT_ABORT_ON_ERROR(g_ort->CreateEnv(ORT_LOGGING_LEVEL_ERROR, "test", &env)); + ORT_ABORT_ON_ERROR(g_ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env)); ::benchmark::RunSpecifiedBenchmarks(); g_ort->ReleaseEnv(env); return 0; diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc index fd2e46c6bbbf4..9cfda7033e8d2 100644 --- a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc +++ b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc @@ -3,7 +3,12 @@ #include +#include +#include +#include +#include #include +#include #include #include #include @@ -11,9 +16,11 @@ #include #include +#include #include #include +#include #include #include @@ -21,6 +28,23 @@ extern OrtEnv* env; extern const OrtApi* g_ort; namespace { +constexpr const char* kMatMulNBitsAutoTunerEnvVar = "ORT_WEBGPU_MATMUL_NBITS_ENABLE_AUTO_TUNER"; +constexpr const char* kDecodeBenchmarkModeEnvVar = "ORT_WEBGPU_MATMUL_NBITS_BENCHMARK_MODE"; +constexpr const char* kDecodeBenchmarkGpuEnvVar = "ORT_WEBGPU_MATMUL_NBITS_BENCHMARK_GPU"; +constexpr float kDecodeCorrectnessAbsTolerance = 0.1f; +constexpr float kDecodeCorrectnessRelTolerance = 0.01f; + +enum class DecodeBenchmarkMode { + kPerf, + kCorrectness, +}; + +enum class DecodeBenchmarkGpu { + kRtx5060Ti, + kT1000, +}; + +bool IsMatMulNBitsAutoTunerEnabled(); struct DecodeBenchConfig { int64_t n; @@ -32,11 +56,16 @@ struct DecodeBenchConfig { struct AdapterSelectionConfig { // adapter_type: Dawn adapter type to select, e.g. integrated or discrete GPU. - // adapter_index: zero-based index among only adapters of adapter_type. + // preferred_vendor_id/device_id: stable PCI identifiers used to locate the target GPU regardless of enumeration order. + // preferred_device_substring: fallback name match if device IDs are unavailable or change. + // adapter_index: fallback zero-based index among only adapters of adapter_type if the preferred adapter is not found. // context_id: ORT WebGPU custom context ID used to bind the externally created instance/device. // backend_type: Dawn backend to enumerate adapters from, e.g. D3D12 or Vulkan. // print_adapter_list: whether to print all discovered adapters before selecting one. WGPUAdapterType adapter_type; + uint32_t preferred_vendor_id; + uint32_t preferred_device_id; + const char* preferred_device_substring; int adapter_index; int context_id; WGPUBackendType backend_type; @@ -64,6 +93,116 @@ struct SelectedWebGpuContext { std::string selected_adapter_summary; }; +struct DecodeTrafficStats { + double input_bytes; + double packed_weight_bytes; + double scale_bytes; + double output_bytes; + double total_bytes; +}; + +constexpr double kRtx5060TiTheoreticalBandwidthBytesPerSecond = 448.0 * 1000.0 * 1000.0 * 1000.0; +constexpr int kDecodeWarmupRuns = 25; + +DecodeBenchmarkMode GetDecodeBenchmarkMode() { + std::string mode_env = onnxruntime::Env::Default().GetEnvironmentVar(kDecodeBenchmarkModeEnvVar); + if (mode_env.empty()) { + return DecodeBenchmarkMode::kPerf; + } + + std::transform(mode_env.begin(), mode_env.end(), mode_env.begin(), + [](unsigned char value) { return static_cast(std::tolower(value)); }); + if (mode_env == "0" || mode_env == "false" || mode_env == "off" || + mode_env == "check" || mode_env == "correctness" || mode_env == "validate") { + return DecodeBenchmarkMode::kCorrectness; + } + + return DecodeBenchmarkMode::kPerf; +} + +bool IsDecodeBenchmarkPerfMode() { + return GetDecodeBenchmarkMode() == DecodeBenchmarkMode::kPerf; +} + +DecodeBenchmarkGpu GetDecodeBenchmarkGpu() { + std::string gpu_env = onnxruntime::Env::Default().GetEnvironmentVar(kDecodeBenchmarkGpuEnvVar); + if (gpu_env.empty()) { + return DecodeBenchmarkGpu::kRtx5060Ti; + } + + std::transform(gpu_env.begin(), gpu_env.end(), gpu_env.begin(), + [](unsigned char value) { return static_cast(std::tolower(value)); }); + if (gpu_env == "t" || gpu_env == "t1000") { + return DecodeBenchmarkGpu::kT1000; + } + + return DecodeBenchmarkGpu::kRtx5060Ti; +} + +std::string GetDecodeBenchmarkLabel() { + const char* mode_label = IsDecodeBenchmarkPerfMode() ? "perf" : "correctness"; + const char* adapter_label = GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kRtx5060Ti ? "rtx" : "t"; + const char* tuner_label = IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off"; + + std::ostringstream stream; + stream << "fp16_decode_" << mode_label << '_' << adapter_label << '_' << tuner_label; + return stream.str(); +} + +bool IsMatMulNBitsAutoTunerEnabled() { + std::string auto_tuner_env = onnxruntime::Env::Default().GetEnvironmentVar(kMatMulNBitsAutoTunerEnvVar); + if (auto_tuner_env.empty()) { + return false; + } + + std::transform(auto_tuner_env.begin(), auto_tuner_env.end(), auto_tuner_env.begin(), + [](unsigned char value) { return static_cast(std::tolower(value)); }); + return auto_tuner_env != "0" && auto_tuner_env != "false" && auto_tuner_env != "off"; +} + +std::vector GetRequiredDeviceFeatures(const wgpu::Adapter& adapter) { + std::vector required_features; + constexpr wgpu::FeatureName features[]{ +#if !defined(__wasm__) + wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, + wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix, +#endif + wgpu::FeatureName::TimestampQuery, + wgpu::FeatureName::ShaderF16, + wgpu::FeatureName::Subgroups, +#if !defined(__wasm__) + wgpu::FeatureName::BufferMapExtendedUsages, +#endif + }; + for (auto feature : features) { + if (adapter.HasFeature(feature)) { + required_features.push_back(feature); + } + } + return required_features; +} + +wgpu::Limits GetRequiredDeviceLimits(const wgpu::Adapter& adapter) { + wgpu::Limits required_limits{}; + wgpu::Limits adapter_limits{}; + if (!adapter.GetLimits(&adapter_limits)) { + throw std::runtime_error("Failed to query adapter limits for the selected WebGPU adapter."); + } + + required_limits.maxBindGroups = adapter_limits.maxBindGroups; + required_limits.maxComputeWorkgroupStorageSize = adapter_limits.maxComputeWorkgroupStorageSize; + required_limits.maxComputeWorkgroupsPerDimension = adapter_limits.maxComputeWorkgroupsPerDimension; + required_limits.maxStorageBuffersPerShaderStage = adapter_limits.maxStorageBuffersPerShaderStage; + required_limits.maxStorageBufferBindingSize = adapter_limits.maxStorageBufferBindingSize; + required_limits.maxBufferSize = adapter_limits.maxBufferSize; + required_limits.maxComputeInvocationsPerWorkgroup = adapter_limits.maxComputeInvocationsPerWorkgroup; + required_limits.maxComputeWorkgroupSizeX = adapter_limits.maxComputeWorkgroupSizeX; + required_limits.maxComputeWorkgroupSizeY = adapter_limits.maxComputeWorkgroupSizeY; + required_limits.maxComputeWorkgroupSizeZ = adapter_limits.maxComputeWorkgroupSizeZ; + + return required_limits; +} + std::string ToString(WGPUStringView value) { return value.data == nullptr ? std::string{} : std::string(value.data, value.length); } @@ -100,12 +239,60 @@ std::string FormatAdapterSummary(const AdapterCandidate& adapter) { return stream.str(); } +std::string FormatFeatureSupport(const dawn::native::Adapter& adapter) { + const wgpu::Adapter wgpu_adapter = adapter.Get(); + std::ostringstream stream; + stream << "shader_f16=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::ShaderF16) ? "yes" : "no") + << " subgroups=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::Subgroups) ? "yes" : "no") + << " timestamp_query=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::TimestampQuery) ? "yes" : "no"); +#if !defined(__wasm__) + stream << " subgroup_matrix=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix) ? "yes" : "no") + << " buffer_map_extended_usages=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::BufferMapExtendedUsages) ? "yes" : "no") + << " timestamp_query_inside_passes=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses) ? "yes" : "no"); +#endif + return stream.str(); +} + +DecodeTrafficStats CalculateDecodeTrafficStats(const DecodeBenchConfig& config) { + const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; + const int64_t blob_size = (config.block_size * config.bits) / 8; + + const double input_bytes = static_cast(config.k) * sizeof(Ort::Float16_t); + const double packed_weight_bytes = static_cast(config.n) * static_cast(k_blocks) * static_cast(blob_size); + const double scale_bytes = static_cast(config.n) * static_cast(k_blocks) * sizeof(Ort::Float16_t); + const double output_bytes = static_cast(config.n) * sizeof(Ort::Float16_t); + + return { + input_bytes, + packed_weight_bytes, + scale_bytes, + output_bytes, + input_bytes + packed_weight_bytes + scale_bytes + output_bytes, + }; +} + AdapterSelectionConfig GetAdapterSelectionConfig() { - // Pick the second discrete adapter by default so this benchmark can target the - // "other" dGPU on a machine with two discrete GPUs and one integrated GPU. + if (GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kT1000) { + return { + WGPUAdapterType_DiscreteGPU, // adapter_type + 4318, // preferred_vendor_id (NVIDIA) + 8112, // preferred_device_id (T1000) + "T1000", // preferred_device_substring + 0, // adapter_index fallback + 0, // context_id + WGPUBackendType_D3D12, // backend_type + true, // print_adapter_list + }; + } + + // Prefer the RTX 5060 Ti by stable PCI identity so selection does not depend on + // Dawn enumeration order. Fall back to the historical second discrete adapter. return { WGPUAdapterType_DiscreteGPU, // adapter_type - 1, // adapter_index + 4318, // preferred_vendor_id (NVIDIA) + 11524, // preferred_device_id (RTX 5060 Ti) + "RTX 5060 Ti", // preferred_device_substring + 1, // adapter_index fallback 1, // context_id WGPUBackendType_D3D12, // backend_type true, // print_adapter_list @@ -115,8 +302,20 @@ AdapterSelectionConfig GetAdapterSelectionConfig() { SelectedWebGpuContext CreateSelectedWebGpuContext() { const AdapterSelectionConfig config = GetAdapterSelectionConfig(); + wgpu::InstanceFeatureName required_instance_features[] = {wgpu::InstanceFeatureName::TimedWaitAny}; + wgpu::InstanceDescriptor instance_desc{}; + instance_desc.requiredFeatures = required_instance_features; + instance_desc.requiredFeatureCount = sizeof(required_instance_features) / sizeof(required_instance_features[0]); + SelectedWebGpuContext selected_context; - selected_context.dawn_instance = std::make_unique(); + selected_context.dawn_instance = std::make_unique(&instance_desc); + +#if !defined(BUILD_DAWN_SHARED_LIBRARY) + static std::once_flag dawn_procs_initialized; + std::call_once(dawn_procs_initialized, []() { + dawnProcSetProcs(&dawn::native::GetProcs()); + }); +#endif WGPURequestAdapterOptions adapter_options = WGPU_REQUEST_ADAPTER_OPTIONS_INIT; adapter_options.backendType = config.backend_type; @@ -180,7 +379,9 @@ SelectedWebGpuContext CreateSelectedWebGpuContext() { } printed_gpu = true; - std::cout << " " << FormatAdapterSummary(candidate) << std::endl; + std::cout << " " << FormatAdapterSummary(candidate) + << " features={" << FormatFeatureSupport(candidate.adapter) << "}" + << std::endl; } if (!printed_gpu) { @@ -188,32 +389,68 @@ SelectedWebGpuContext CreateSelectedWebGpuContext() { } } - const AdapterCandidate* selected_adapter = nullptr; - for (const auto& candidate : candidates) { + AdapterCandidate* selected_adapter = nullptr; + for (auto& candidate : candidates) { + if (candidate.adapter_type == config.adapter_type && + candidate.vendor_id == config.preferred_vendor_id && + candidate.device_id == config.preferred_device_id) { + selected_adapter = &candidate; + break; + } + } + + if (selected_adapter == nullptr && config.preferred_device_substring != nullptr) { + for (auto& candidate : candidates) { + if (candidate.adapter_type == config.adapter_type && + candidate.device.find(config.preferred_device_substring) != std::string::npos) { + selected_adapter = &candidate; + break; + } + } + } + + if (selected_adapter == nullptr) { + for (auto& candidate : candidates) { if (candidate.adapter_type == config.adapter_type && candidate.type_index == config.adapter_index) { selected_adapter = &candidate; break; } + } } if (selected_adapter == nullptr) { std::ostringstream stream; - stream << "Failed to find " << AdapterTypeToString(config.adapter_type) - << " adapter index " << config.adapter_index + stream << "Failed to find preferred " << AdapterTypeToString(config.adapter_type) + << " adapter vendor_id=" << config.preferred_vendor_id + << " device_id=" << config.preferred_device_id + << " name~=" << (config.preferred_device_substring ? config.preferred_device_substring : "") + << ", or fallback adapter index " << config.adapter_index << ". Update GetAdapterSelectionConfig() to match the available adapters listed above."; throw std::runtime_error(stream.str()); } + const wgpu::Adapter adapter = selected_adapter->adapter.Get(); + std::vector required_features = GetRequiredDeviceFeatures(adapter); + wgpu::Limits required_limits = GetRequiredDeviceLimits(adapter); + wgpu::DeviceDescriptor device_desc{}; + if (!required_features.empty()) { + device_desc.requiredFeatures = required_features.data(); + device_desc.requiredFeatureCount = required_features.size(); + } + device_desc.requiredLimits = &required_limits; + selected_context.instance = selected_context.dawn_instance->Get(); - selected_context.device = selected_adapter->adapter.CreateDevice(); + selected_context.device = selected_adapter->adapter.CreateDevice(&device_desc); if (selected_context.device == nullptr) { throw std::runtime_error("Failed to create a WGPUDevice for the selected adapter."); } selected_context.selected_adapter_summary = FormatAdapterSummary(*selected_adapter); std::cout << "Selected Dawn adapter for WebGPU benchmark: " - << selected_context.selected_adapter_summary << std::endl; + << selected_context.selected_adapter_summary + << " features={" << FormatFeatureSupport(selected_adapter->adapter) << "}" + << std::endl; selected_context.provider_options["deviceId"] = std::to_string(config.context_id); selected_context.provider_options["webgpuInstance"] = std::to_string(reinterpret_cast(selected_context.instance)); @@ -248,10 +485,18 @@ void AddTensorInitializer(ONNX_NAMESPACE::GraphProto& graph, std::vector GetDecodeBenchConfigs() { // Each entry is {N, K, bits, block_size, accuracy_level} for a decode-style M=1 run. return { - {5120, 3072, 4, 32, 4}, - {8192, 3072, 4, 32, 4}, - {3072, 8192, 4, 32, 4}, - {200064, 3072, 4, 32, 4}, + // QKV + AttnProj + {1024, 2048, 4, 32, 4}, + {2048, 2048, 4, 32, 4}, + + // Gate + Up proj + {6144, 2048, 4, 32, 4}, + + // Down proj + {2048, 6144, 4, 32, 4}, + + // Vocab proj + {151936, 2048, 4, 32, 4}, }; } @@ -260,7 +505,6 @@ void AddMatMulNBitsNode(ONNX_NAMESPACE::GraphProto& graph, const std::string& input_name, const std::string& weight_name, const std::string& scale_name, - const std::string& bias_name, const std::string& output_name, int64_t k, int64_t n, @@ -276,7 +520,6 @@ void AddMatMulNBitsNode(ONNX_NAMESPACE::GraphProto& graph, node->add_input(scale_name); node->add_input(""); node->add_input(""); - node->add_input(bias_name); node->add_output(output_name); auto* attr_k = node->add_attribute(); @@ -336,21 +579,17 @@ std::vector SerializeMatMulNBitsModel(const DecodeBenchConfig& config) std::vector packed_b(static_cast(config.n * k_blocks * blob_size), uint8_t{0x11}); std::vector scales(static_cast(config.n * k_blocks), Ort::Float16_t(0.03125f)); - std::vector bias(static_cast(config.n), Ort::Float16_t(0.125f)); AddTensorInitializer(*graph, "B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, {config.n, k_blocks, blob_size}, packed_b); AddTensorInitializer(*graph, "scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.n, k_blocks}, scales); - AddTensorInitializer(*graph, "bias", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - {config.n}, bias); AddMatMulNBitsNode(*graph, "MatMulNBitsDecode", "A", "B", "scales", - "bias", "Y", config.k, config.n, @@ -362,71 +601,155 @@ std::vector SerializeMatMulNBitsModel(const DecodeBenchConfig& config) return std::vector(serialized.begin(), serialized.end()); } -static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { - const DecodeBenchConfig config{ - state.range(0), - state.range(1), - state.range(2), - state.range(3), - state.range(4), - }; - - if (config.k % config.block_size != 0) { - state.SkipWithError("K must be divisible by block_size for this benchmark skeleton."); - return; - } - - std::vector model_data = SerializeMatMulNBitsModel(config); - const SelectedWebGpuContext& selected_context = GetSelectedWebGpuContext(); - +Ort::Session CreateSessionFromModelData(const std::vector& model_data, + const std::unordered_map* provider_options) { Ort::SessionOptions session_options; session_options.DisableMemPattern(); - session_options.AppendExecutionProvider("WebGPU", selected_context.provider_options); + if (provider_options != nullptr) { + session_options.AppendExecutionProvider("WebGPU", *provider_options); + } OrtSession* raw_session = nullptr; OrtStatus* status = g_ort->CreateSessionFromArray(env, model_data.data(), model_data.size(), session_options, &raw_session); if (status != nullptr) { - state.SkipWithError(g_ort->GetErrorMessage(status)); + std::string error_message = g_ort->GetErrorMessage(status); g_ort->ReleaseStatus(status); - return; + throw std::runtime_error(error_message); } - Ort::Session session{raw_session}; - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - std::vector input_shape{1, config.k}; - std::vector activation(static_cast(config.k)); + return Ort::Session{raw_session}; +} - std::mt19937 rng(123); - std::uniform_real_distribution dist(-1.0f, 1.0f); - for (auto& value : activation) { - value = Ort::Float16_t(dist(rng)); - } +void ValidateDecodeOutputs(const std::vector& model_data, + Ort::Session& webgpu_session, + const char* const* input_names, + const Ort::Value* input_tensor, + const char* const* output_names) { + Ort::Session cpu_session = CreateSessionFromModelData(model_data, nullptr); - const char* input_names[] = {"A"}; - const char* output_names[] = {"Y"}; + auto webgpu_outputs = webgpu_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensor, 1, output_names, 1); + auto cpu_outputs = cpu_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensor, 1, output_names, 1); - auto input_tensor = Ort::Value::CreateTensor(memory_info, - activation.data(), - activation.size(), - input_shape.data(), - input_shape.size()); + if (webgpu_outputs.size() != 1 || cpu_outputs.size() != 1) { + throw std::runtime_error("Expected a single output from both WebGPU and CPU sessions."); + } - for (int i = 0; i < 10; ++i) { - auto warmup_outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); - benchmark::DoNotOptimize(warmup_outputs); + const auto& webgpu_output = webgpu_outputs[0]; + const auto& cpu_output = cpu_outputs[0]; + const size_t element_count = webgpu_output.GetTensorTypeAndShapeInfo().GetElementCount(); + if (element_count != cpu_output.GetTensorTypeAndShapeInfo().GetElementCount()) { + throw std::runtime_error("WebGPU and CPU output sizes do not match."); } - for (auto _ : state) { - auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); - benchmark::DoNotOptimize(outputs); + const auto* webgpu_data = webgpu_output.GetTensorData(); + const auto* cpu_data = cpu_output.GetTensorData(); + float max_abs_diff = 0.0f; + size_t max_abs_diff_index = 0; + for (size_t i = 0; i < element_count; ++i) { + const float webgpu_value = webgpu_data[i].ToFloat(); + const float cpu_value = cpu_data[i].ToFloat(); + const float abs_diff = std::abs(webgpu_value - cpu_value); + const float allowed_diff = kDecodeCorrectnessAbsTolerance + + kDecodeCorrectnessRelTolerance * std::max(std::abs(webgpu_value), std::abs(cpu_value)); + if (abs_diff > max_abs_diff) { + max_abs_diff = abs_diff; + max_abs_diff_index = i; + } + if (abs_diff > allowed_diff) { + std::ostringstream stream; + stream << "Decode correctness check failed at index " << i + << ": webgpu=" << webgpu_value + << ", cpu=" << cpu_value + << ", abs_diff=" << abs_diff + << ", allowed_diff=" << allowed_diff; + throw std::runtime_error(stream.str()); + } } - const double total_flops = 2.0 * static_cast(config.n) * static_cast(config.k); + std::cout << "Decode correctness check passed. max_abs_diff=" << max_abs_diff + << " at index " << max_abs_diff_index << std::endl; +} - state.SetLabel("fp16_decode_bias_custom_adapter"); - state.counters["TFLOPS"] = benchmark::Counter( - total_flops, - benchmark::Counter::kIsIterationInvariantRate); +static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { + try { + const DecodeBenchConfig config{ + state.range(0), + state.range(1), + state.range(2), + state.range(3), + state.range(4), + }; + + if (config.k % config.block_size != 0) { + state.SkipWithError("K must be divisible by block_size for this benchmark skeleton."); + return; + } + + const DecodeTrafficStats traffic = CalculateDecodeTrafficStats(config); + std::vector model_data = SerializeMatMulNBitsModel(config); + const SelectedWebGpuContext& selected_context = GetSelectedWebGpuContext(); + Ort::Session session = CreateSessionFromModelData(model_data, &selected_context.provider_options); + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + std::vector input_shape{1, config.k}; + std::vector activation(static_cast(config.k)); + + std::mt19937 rng(123); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for (auto& value : activation) { + value = Ort::Float16_t(dist(rng)); + } + + const char* input_names[] = {"A"}; + const char* output_names[] = {"Y"}; + + auto input_tensor = Ort::Value::CreateTensor(memory_info, + activation.data(), + activation.size(), + input_shape.data(), + input_shape.size()); + + if (!IsDecodeBenchmarkPerfMode()) { + ValidateDecodeOutputs(model_data, session, input_names, &input_tensor, output_names); + } + + // Warm up shader compilation, allocations, and caches before measured iterations. + for (int i = 0; i < kDecodeWarmupRuns; ++i) { + auto warmup_outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); + benchmark::DoNotOptimize(warmup_outputs); + } + + double total_kernel_seconds = 0.0; + for (auto _ : state) { + const auto kernel_start = std::chrono::steady_clock::now(); + auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); + const auto kernel_end = std::chrono::steady_clock::now(); + total_kernel_seconds += std::chrono::duration(kernel_end - kernel_start).count(); + benchmark::DoNotOptimize(outputs); + } + + const double total_flops = 2.0 * static_cast(config.n) * static_cast(config.k); + const double achieved_bandwidth_bytes_per_second = + total_kernel_seconds > 0.0 + ? traffic.total_bytes * static_cast(state.iterations()) / total_kernel_seconds + : 0.0; + const double achieved_bandwidth_gbps = achieved_bandwidth_bytes_per_second / 1.0e9; + const double rtx_5060_ti_utilization_pct = + achieved_bandwidth_bytes_per_second / kRtx5060TiTheoreticalBandwidthBytesPerSecond * 100.0; + + state.SetLabel(GetDecodeBenchmarkLabel()); + state.counters["TFLOPS"] = benchmark::Counter( + total_flops, + benchmark::Counter::kIsIterationInvariantRate); + state.counters["MemBW_GBps"] = benchmark::Counter(achieved_bandwidth_gbps); + state.counters["BWUtil_5060Ti_pct"] = benchmark::Counter(rtx_5060_ti_utilization_pct); + state.counters["Traffic_MB"] = benchmark::Counter(traffic.total_bytes / 1.0e6); + state.counters["Input_MB"] = benchmark::Counter(traffic.input_bytes / 1.0e6); + state.counters["PackedW_MB"] = benchmark::Counter(traffic.packed_weight_bytes / 1.0e6); + state.counters["Scales_MB"] = benchmark::Counter(traffic.scale_bytes / 1.0e6); + state.counters["Output_MB"] = benchmark::Counter(traffic.output_bytes / 1.0e6); + } catch (const std::exception& ex) { + state.SkipWithError(ex.what()); + } } void ApplyWebGpuMatMulNBitsDecodeArgs(benchmark::internal::Benchmark* benchmark) { @@ -437,6 +760,8 @@ void ApplyWebGpuMatMulNBitsDecodeArgs(benchmark::internal::Benchmark* benchmark) BENCHMARK(BM_WebGpuMatMulNBitsDecode) ->Apply(ApplyWebGpuMatMulNBitsDecodeArgs) + ->Repetitions(5) + ->ReportAggregatesOnly() ->UseRealTime() ->Unit(benchmark::TimeUnit::kMicrosecond); From aa357eef415bab2055e924a1afbc8696a3a49a12 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Wed, 15 Apr 2026 12:05:13 -0700 Subject: [PATCH 04/26] More changes --- .../webgpu/quantization/matmul_nbits.cc | 613 ++++-------------- .../webgpu/quantization/matmul_nbits.h | 34 + .../matmul_nbits_m1.wgsl.template | 161 +++++ .../webgpu/quantization/matmul_nbits_silu.cc | 330 ++++++++++ .../webgpu/quantization/matmul_nbits_silu.h | 39 ++ .../matmul_nbits_silu_mul.wgsl.template | 169 +++++ .../webgpu/webgpu_contrib_kernels.cc | 2 + .../core/graph/contrib_ops/contrib_defs.cc | 61 ++ .../core/optimizer/graph_transformer_utils.cc | 15 + .../optimizer/matmul_nbits_silu_fusion.cc | 244 +++++++ .../core/optimizer/matmul_nbits_silu_fusion.h | 19 + .../webgpu_matmul_nbits_decode.cc | 421 +++++++++++- .../matmul_nbits_silu_fusion_test.cc | 143 ++++ 13 files changed, 1766 insertions(+), 485 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_m1.wgsl.template create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.h create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template create mode 100644 onnxruntime/core/optimizer/matmul_nbits_silu_fusion.cc create mode 100644 onnxruntime/core/optimizer/matmul_nbits_silu_fusion.h create mode 100644 onnxruntime/test/optimizer/matmul_nbits_silu_fusion_test.cc diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 97bf624d1e91d..afdc49d67765b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -3,20 +3,12 @@ #include #include -#include -#include -#include -#include -#include #include "contrib_ops/webgpu/quantization/matmul_nbits.h" #include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" #include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" #include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" -#include "core/common/inlined_containers.h" -#include "core/common/logging/macros.h" -#include "core/platform/env.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -28,427 +20,6 @@ namespace webgpu { namespace { constexpr unsigned int kMinMForTileOptimization = 4; -constexpr uint32_t kMatMulNBitsMinNForAutoTuning = 65536; -constexpr const char* kMatMulNBitsAutoTunerEnvVar = "ORT_WEBGPU_MATMUL_NBITS_ENABLE_AUTO_TUNER"; -constexpr double kMatMulNBitsTuneMinImprovementRatio = 0.10; - -struct MatMulNBitsTuneParams { - uint32_t tile_size_k_vec; - uint32_t workgroup_size; - uint32_t tile_size; -}; - -std::mutex g_matmul_nbits_tune_mutex; -InlinedHashMap g_matmul_nbits_tuned_params; - -constexpr std::array kMatMulNBitsTuneCandidates{{ - {8u, 64u, 8u}, - {8u, 128u, 8u}, - {8u, 64u, 16u}, - {8u, 128u, 16u}, - {16u, 64u, 8u}, - {16u, 128u, 8u}, - {16u, 64u, 16u}, - {16u, 128u, 16u}, - {32u, 64u, 8u}, - {32u, 128u, 8u}, - {32u, 64u, 16u}, - {32u, 128u, 16u}, -}}; - -constexpr int kMatMulNBitsTuneWarmupRuns = 1; -constexpr int kMatMulNBitsTuneMeasuredRuns = 2; - -std::string_view ToStdStringView(wgpu::StringView value) { - return std::string_view{value.data ? value.data : "", value.length}; -} - -std::string MakeMatMulNBitsTuneKey(const onnxruntime::webgpu::ComputeContext& context, - uint32_t M, - uint32_t N, - uint32_t K, - uint32_t block_size, - uint32_t nbits, - bool single_scale_weights, - bool is_fp16) { - return MakeStringWithClassicLocale(ToStdStringView(context.AdapterInfo().vendor), - "|", - ToStdStringView(context.AdapterInfo().architecture), - "|", - ToStdStringView(context.AdapterInfo().device), - "|M=", - M, - "|N=", - N, - "|K=", - K, - "|block=", - block_size, - "|bits=", - nbits, - "|single_scale=", - single_scale_weights ? 1 : 0, - "|fp16=", - is_fp16 ? 1 : 0); -} - -bool IsValidTuneCandidate(const onnxruntime::webgpu::ComputeContext& context, - const MatMulNBitsTuneParams& candidate) { - if (candidate.workgroup_size % candidate.tile_size_k_vec != 0 || - candidate.workgroup_size > context.DeviceLimits().maxComputeInvocationsPerWorkgroup) { - return false; - } - - const uint32_t sub_tile_count = candidate.workgroup_size / candidate.tile_size_k_vec; - return sub_tile_count > 0 && candidate.tile_size % sub_tile_count == 0; -} - -MatMulNBitsTuneParams GetDefaultMatMulNBitsTuneParams(const onnxruntime::webgpu::ComputeContext& context) { - return MatMulNBitsTuneParams{ - (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u, - 128u, - 8u, - }; -} - -bool ShouldTuneDefaultMatMulNBitsProgram(const onnxruntime::webgpu::ComputeContext& context, - uint32_t batch_count, - uint32_t M, - uint32_t N, - bool has_zero_points, - bool has_bias, - bool has_weight_idx, - bool has_weight_idx_indirect) { - std::string auto_tuner_env = Env::Default().GetEnvironmentVar(kMatMulNBitsAutoTunerEnvVar); - if (auto_tuner_env.empty()) { - return false; - } - - std::transform(auto_tuner_env.begin(), auto_tuner_env.end(), auto_tuner_env.begin(), - [](unsigned char value) { return static_cast(std::tolower(value)); }); - if (auto_tuner_env == "0" || auto_tuner_env == "false" || auto_tuner_env == "off") { - return false; - } - - return !context.IsGraphCaptureEnabled() && - batch_count == 1 && - M == 1 && - N >= kMatMulNBitsMinNForAutoTuning && - !has_zero_points && - !has_bias && - !has_weight_idx && - !has_weight_idx_indirect; -} - -Status RunDefaultMatMulNBitsProgram(const Tensor* a, - const Tensor* b, - const Tensor* scales, - const Tensor* zero_points, - const Tensor* bias, - uint32_t batch_count, - uint32_t M, - uint32_t N, - uint32_t K, - uint32_t block_size, - uint32_t n_blocks_per_col, - uint32_t zero_blocks_per_col, - uint32_t blob_size, - uint32_t nbits, - bool has_zero_points, - bool has_bias, - bool has_weight_idx, - bool has_weight_idx_indirect, - bool single_scale_weights, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y, - uint32_t weight_index, - const Tensor* weight_index_indirect, - const MatMulNBitsTuneParams& params, - bool wait_for_completion = false) { - constexpr uint32_t kU32Components = 4; - const uint32_t components_a = GetMaxComponents(K); - const uint32_t blob_size_in_words = blob_size / 4; - const uint32_t components_b = GetMaxComponents(blob_size_in_words); - const uint32_t components_b_with_u32 = components_b * kU32Components; - const uint32_t num_N_tile = CeilDiv(N, params.tile_size); - const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; - - MatMulNBitsProgram program{params.tile_size, - nbits, - has_zero_points, - has_bias, - has_weight_idx, - has_weight_idx_indirect, - single_scale_weights, - params.tile_size_k_vec}; - program.SetWorkgroupSize(params.workgroup_size); - program.SetDispatchGroupSize(num_N_tile, M, batch_count); - program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, - {scales, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) - .AddUniformVariables({{M}, - {N}, - {K}, - {K / components_a}, - {K_of_b}, - {block_size}, - {n_blocks_per_col}, - {zero_blocks_per_col}, - {num_N_tile}, - {batch_count}, - {weight_index}}) - .CacheHint(nbits, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect, params.tile_size_k_vec); - if (has_zero_points) { - program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); - } - if (has_bias) { - program.AddInput({bias, ProgramTensorMetadataDependency::None}); - } - if (has_weight_idx_indirect) { - program.AddInput({weight_index_indirect, ProgramTensorMetadataDependency::None}); - } - - ORT_RETURN_IF_ERROR(context.RunProgram(program)); - if (wait_for_completion) { - return context.FlushAndWait(); - } - return Status::OK(); -} - -MatMulNBitsTuneParams GetTunedMatMulNBitsParams(const Tensor* a, - const Tensor* b, - const Tensor* scales, - const Tensor* zero_points, - const Tensor* bias, - uint32_t batch_count, - uint32_t M, - uint32_t N, - uint32_t K, - uint32_t block_size, - uint32_t n_blocks_per_col, - uint32_t zero_blocks_per_col, - uint32_t blob_size, - uint32_t nbits, - bool has_zero_points, - bool has_bias, - bool has_weight_idx, - bool has_weight_idx_indirect, - bool single_scale_weights, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y, - uint32_t weight_index, - const Tensor* weight_index_indirect) { - const bool is_fp16 = y->DataType() == DataTypeImpl::GetType(); - const std::string tune_key = MakeMatMulNBitsTuneKey(context, M, N, K, block_size, nbits, single_scale_weights, is_fp16); - - { - std::lock_guard lock(g_matmul_nbits_tune_mutex); - const auto it = g_matmul_nbits_tuned_params.find(tune_key); - if (it != g_matmul_nbits_tuned_params.end()) { - return it->second; - } - } - - const MatMulNBitsTuneParams default_params = GetDefaultMatMulNBitsTuneParams(context); - MatMulNBitsTuneParams best_params = default_params; - double default_seconds = 0.0; - double best_seconds = 0.0; - - bool default_failed = false; - for (int i = 0; i < kMatMulNBitsTuneWarmupRuns; ++i) { - const Status status = RunDefaultMatMulNBitsProgram(a, - b, - scales, - zero_points, - bias, - batch_count, - M, - N, - K, - block_size, - n_blocks_per_col, - zero_blocks_per_col, - blob_size, - nbits, - has_zero_points, - has_bias, - has_weight_idx, - has_weight_idx_indirect, - single_scale_weights, - context, - y, - weight_index, - weight_index_indirect, - default_params, - true); - if (!status.IsOK()) { - default_failed = true; - break; - } - } - - if (!default_failed) { - for (int i = 0; i < kMatMulNBitsTuneMeasuredRuns; ++i) { - const auto start = std::chrono::steady_clock::now(); - const Status status = RunDefaultMatMulNBitsProgram(a, - b, - scales, - zero_points, - bias, - batch_count, - M, - N, - K, - block_size, - n_blocks_per_col, - zero_blocks_per_col, - blob_size, - nbits, - has_zero_points, - has_bias, - has_weight_idx, - has_weight_idx_indirect, - single_scale_weights, - context, - y, - weight_index, - weight_index_indirect, - default_params, - true); - const auto end = std::chrono::steady_clock::now(); - if (!status.IsOK()) { - default_failed = true; - break; - } - default_seconds += std::chrono::duration(end - start).count(); - } - } - - if (default_failed) { - LOGS(context.Logger(), WARNING) << "MatMulNBits tuner kept default params for " << tune_key - << " because the baseline measurement failed" - << ": tile_size_k_vec=" << default_params.tile_size_k_vec - << ", workgroup_size=" << default_params.workgroup_size - << ", tile_size=" << default_params.tile_size; - - std::lock_guard lock(g_matmul_nbits_tune_mutex); - g_matmul_nbits_tuned_params.insert_or_assign(tune_key, default_params); - return default_params; - } - - best_seconds = default_seconds; - - for (const auto& candidate : kMatMulNBitsTuneCandidates) { - if (!IsValidTuneCandidate(context, candidate)) { - continue; - } - - if (candidate.tile_size_k_vec == default_params.tile_size_k_vec && - candidate.workgroup_size == default_params.workgroup_size && - candidate.tile_size == default_params.tile_size) { - continue; - } - - bool candidate_failed = false; - for (int i = 0; i < kMatMulNBitsTuneWarmupRuns; ++i) { - const Status status = RunDefaultMatMulNBitsProgram(a, - b, - scales, - zero_points, - bias, - batch_count, - M, - N, - K, - block_size, - n_blocks_per_col, - zero_blocks_per_col, - blob_size, - nbits, - has_zero_points, - has_bias, - has_weight_idx, - has_weight_idx_indirect, - single_scale_weights, - context, - y, - weight_index, - weight_index_indirect, - candidate, - true); - if (!status.IsOK()) { - candidate_failed = true; - break; - } - } - if (candidate_failed) { - continue; - } - - double candidate_seconds = 0.0; - for (int i = 0; i < kMatMulNBitsTuneMeasuredRuns; ++i) { - const auto start = std::chrono::steady_clock::now(); - const Status status = RunDefaultMatMulNBitsProgram(a, - b, - scales, - zero_points, - bias, - batch_count, - M, - N, - K, - block_size, - n_blocks_per_col, - zero_blocks_per_col, - blob_size, - nbits, - has_zero_points, - has_bias, - has_weight_idx, - has_weight_idx_indirect, - single_scale_weights, - context, - y, - weight_index, - weight_index_indirect, - candidate, - true); - const auto end = std::chrono::steady_clock::now(); - if (!status.IsOK()) { - candidate_failed = true; - break; - } - candidate_seconds += std::chrono::duration(end - start).count(); - } - - if (!candidate_failed && candidate_seconds < best_seconds) { - best_seconds = candidate_seconds; - best_params = candidate; - } - } - - const double improvement_ratio = default_seconds > 0.0 - ? (default_seconds - best_seconds) / default_seconds - : 0.0; - if (improvement_ratio <= kMatMulNBitsTuneMinImprovementRatio) { - best_params = default_params; - best_seconds = default_seconds; - } - - LOGS(context.Logger(), WARNING) << "MatMulNBits tuner selected params for " << tune_key - << ": tile_size_k_vec=" << best_params.tile_size_k_vec - << ", workgroup_size=" << best_params.workgroup_size - << ", tile_size=" << best_params.tile_size - << ", default_measured_seconds=" << default_seconds - << ", selected_measured_seconds=" << best_seconds - << ", improvement_ratio=" << improvement_ratio; - - std::lock_guard lock(g_matmul_nbits_tune_mutex); - g_matmul_nbits_tuned_params.insert_or_assign(tune_key, best_params); - return best_params; -} } // namespace ONNX_OPERATOR_KERNEL_EX( @@ -497,6 +68,43 @@ Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) cons WGSL_TEMPLATE_VARIABLE(scales, scales)); } +Status MatMulNBitsM1Program::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b"); + const auto& scales_b = shader.AddInput("scales_b"); + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseElementTypeAlias); + + const uint32_t components_a = a.NumComponents(); + const uint32_t components_b = b.NumComponents() / 4; + const uint32_t elements_in_value_b = components_b * 8u; + const uint32_t a_length_per_lane = elements_in_value_b / components_a; + const uint32_t tile_size_k = tile_size_k_vec_ * elements_in_value_b; + ORT_ENFORCE(tile_size_ % 4u == 0u, "tile_size must be divisible by 4 for MatMulNBitsM1Program."); + + return WGSL_TEMPLATE_APPLY(shader, + "quantization/matmul_nbits_m1.wgsl.template", + WGSL_TEMPLATE_PARAMETER(a_length_per_lane, a_length_per_lane), + WGSL_TEMPLATE_PARAMETER(component_a, components_a), + WGSL_TEMPLATE_PARAMETER(component_b, components_b), + WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), + WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_), + WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_), + WGSL_TEMPLATE_PARAMETER(has_zero_points, false), + WGSL_TEMPLATE_PARAMETER(n_bits, 4), + WGSL_TEMPLATE_PARAMETER(output_type_i32, false), + WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), + WGSL_TEMPLATE_PARAMETER(subgroup_tile_size, tile_size_ / 4u), + WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), + WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec_), + WGSL_TEMPLATE_VARIABLE(a, a), + WGSL_TEMPLATE_VARIABLE(b, b), + WGSL_TEMPLATE_VARIABLE(output, output), + WGSL_TEMPLATE_VARIABLE(scales_b, scales_b)); +} + // Apply similar idea with DP4AMatMulNBitsSmallMProgram algorithm. Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias); @@ -653,7 +261,7 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, #if !defined(__wasm__) int32_t subgroup_matrix_config_index = -1; // apple|intel - Experimental dawn support for subgroup matrix matmul. - if (M >= kMinMForTileOptimization && (context.AdapterInfo().vendor == std::string_view{"apple"} || context.AdapterInfo().vendor == std::string_view{"intel"}) && + if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) && (context.AdapterInfo().vendor == std::string_view{"apple"} || context.AdapterInfo().vendor == std::string_view{"intel"}) && CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, static_cast(nbits), y->DataType() == DataTypeImpl::GetType(), subgroup_matrix_config_index)) { return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index, weight_index_indirect); } @@ -661,14 +269,15 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, // On FP32 only GPUs and Qualcomm GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M. // DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values). - if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && + if (((M >= kMinMForTileOptimization && !has_weight_idx_indirect) || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) { return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast(nbits), context, y, weight_index, weight_index_indirect); } // WideTileProgram // This program is optimized for Block32 prefill using Tile16x128. - const bool use_wide_tile_program = block_size == 32 && + const bool use_wide_tile_program = !has_weight_idx_indirect && + block_size == 32 && components_a == 4 && components_b == 4 && nbits != 2 && @@ -730,59 +339,95 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, return context.RunProgram(program); } - MatMulNBitsTuneParams params{}; - if (ShouldTuneDefaultMatMulNBitsProgram(context, batch_count, M, N, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect)) { - params = GetTunedMatMulNBitsParams(a, - b, - scales, - zero_points, - bias, - batch_count, - M, - N, - K, - block_size, - n_blocks_per_col, - zero_blocks_per_col, - blob_size, - static_cast(nbits), - has_zero_points, - has_bias, - has_weight_idx, - has_weight_idx_indirect, - single_scale_weights, - context, - y, - weight_index, - weight_index_indirect); - } else { - params = GetDefaultMatMulNBitsTuneParams(context); + // Disabled for now so decode uses the existing generic fallback path that is + // already deployed in production. We can reintroduce M1-specific ideas when + // tuning the fused implementation. + // const bool use_m1_subgroup_program = M == 1 && + // batch_count >= 1 && + // static_cast(nbits) == 4u && + // !has_zero_points && + // !has_weight_idx_indirect && + // context.AdapterInfo().vendor == std::string_view{"nvidia"} && + // context.HasFeature(wgpu::FeatureName::Subgroups); + // + // if (use_m1_subgroup_program) { + // constexpr uint32_t workgroup_size = 128; + // constexpr uint32_t tile_size = 8; + // constexpr uint32_t tile_size_k_vec = 32; + // constexpr uint32_t kU32Components = 4; + // const uint32_t components_b_with_u32 = components_b * kU32Components; + // const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; + // + // MatMulNBitsM1Program program{tile_size, + // has_bias, + // has_weight_idx, + // single_scale_weights, + // tile_size_k_vec}; + // program.SetWorkgroupSize(workgroup_size); + // const uint32_t num_N_tile = CeilDiv(N, tile_size); + // program.SetDispatchGroupSize(num_N_tile, 1, batch_count); + // program + // .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, + // {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + // {scales, ProgramTensorMetadataDependency::TypeAndRank}}) + // .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) + // .AddUniformVariables({{N}, + // {K}, + // {K / components_a}, + // {K_of_b}, + // {block_size}, + // {n_blocks_per_col}, + // {num_N_tile}, + // {batch_count}, + // {weight_index}}) + // .CacheHint(has_bias, has_weight_idx, single_scale_weights, tile_size_k_vec); + // if (has_bias) { + // program.AddInput({bias, ProgramTensorMetadataDependency::None}); + // } + // return context.RunProgram(program); + // } + + // Use tile_size_k_vec=32 by default for better K-dimension parallelism. + // Intel devices use 16 as they have different subgroup/cache characteristics. + const uint32_t tile_size_k_vec = + (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; + + constexpr uint32_t workgroup_size = 128; + constexpr uint32_t tile_size = 8; + constexpr uint32_t kU32Components = 4; + uint32_t components_b_with_u32 = components_b * kU32Components; + uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; + MatMulNBitsProgram program{tile_size, static_cast(nbits), has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights, tile_size_k_vec}; + program.SetWorkgroupSize(workgroup_size); + uint32_t num_N_tile = (N + tile_size - 1) / tile_size; + program.SetDispatchGroupSize(num_N_tile, M, batch_count); + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) + .AddUniformVariables({{M}, + {N}, + {K}, + {K / components_a}, + {K_of_b}, + {block_size}, + {n_blocks_per_col}, + {zero_blocks_per_col}, + {num_N_tile}, + {batch_count}, + {weight_index}}) + .CacheHint(nbits, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect, tile_size_k_vec); + if (has_zero_points) { + program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } - - return RunDefaultMatMulNBitsProgram(a, - b, - scales, - zero_points, - bias, - batch_count, - M, - N, - K, - block_size, - n_blocks_per_col, - zero_blocks_per_col, - blob_size, - static_cast(nbits), - has_zero_points, - has_bias, - has_weight_idx, - has_weight_idx_indirect, - single_scale_weights, - context, - y, - weight_index, - weight_index_indirect, - params); + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::None}); + } + if (has_weight_idx_indirect) { + program.AddInput({weight_index_indirect, ProgramTensorMetadataDependency::None}); + } + return context.RunProgram(program); } } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 295a0fb90dd2a..a009a7f757b55 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -39,6 +39,40 @@ class MatMulNBitsWideTileProgram final : public Program { + public: + MatMulNBitsM1Program(uint32_t tile_size, + bool has_bias, + bool has_weight_idx, + bool single_scale_weights, + uint32_t tile_size_k_vec = 32) + : Program{"MatMulNBitsM1"}, + tile_size_(tile_size), + has_bias_(has_bias), + has_weight_idx_{has_weight_idx}, + single_scale_weights_(single_scale_weights), + tile_size_k_vec_(tile_size_k_vec) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K_of_a", ProgramUniformVariableDataType::Uint32}, + {"K_of_b", ProgramUniformVariableDataType::Uint32}, + {"block_size", ProgramUniformVariableDataType::Uint32}, + {"blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"num_N_tile", ProgramUniformVariableDataType::Uint32}, + {"batch_count", ProgramUniformVariableDataType::Uint32}, + {"weight_idx", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t tile_size_; + bool has_bias_; + bool has_weight_idx_; + bool single_scale_weights_; + uint32_t tile_size_k_vec_; +}; + class MatMulNBitsProgram final : public Program { public: MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect, bool single_scale_weights, uint32_t tile_size_k_vec = 16) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_m1.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_m1.wgsl.template new file mode 100644 index 0000000000000..149c89ae8017a --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_m1.wgsl.template @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param a_length_per_tile +#param a_length_per_lane +#param component_a +#param component_b +#param elements_in_value_b +#param has_bias +#param has_weight_idx +#param single_scale_weights +#param subgroup_tile_size +#param tile_size_k +#param tile_size_k_vec + +#use .getByOffset .setByOffset + +#include "quantization/matmul_nbits_zero_pt.wgsl.template" + +fn load_lane_a(batch: u32, kidx: u32, lane_col_base: u32, col: u32) -> input_a_value_t { + let k_offset = kidx / component_a + lane_col_base + col; + if (batch < uniforms.batch_count && k_offset < uniforms.K_of_a) { + return a.getByOffset(batch * uniforms.K_of_a + k_offset); + } + + return input_a_value_t(0); +} + +$MAIN { + let batch = workgroup_id.z; + let subgroup_id = local_id.x / tile_size_k_vec; + let lane_id = local_id.x % tile_size_k_vec; + let b_global_base = workgroup_id.x * subgroup_tile_size * 4u + subgroup_id * subgroup_tile_size; + +#if has_weight_idx + let actual_weight_idx = uniforms.weight_idx; + let b_base_offset = actual_weight_idx * uniforms.K_of_b * uniforms.N; +#if single_scale_weights + let b_scale_offset = actual_weight_idx; +#else + let b_scale_offset = actual_weight_idx * uniforms.N * uniforms.blocks_per_col; +#endif +#else + const b_base_offset : u32 = 0u; + const b_scale_offset : u32 = 0u; + const actual_weight_idx : u32 = 0u; +#endif + +#if has_bias + let b_bias_offset = actual_weight_idx * uniforms.N; +#endif + + var accum : array; + for (var output_idx = 0u; output_idx < subgroup_tile_size; output_idx++) { + accum[output_idx] = output_element_t(0); + } + +#if single_scale_weights + let block_idx = 0u; +#endif + + let lane_col_base = lane_id * a_length_per_lane; + + for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) { + var lane_a : array; + for (var col = 0u; col < a_length_per_lane; col++) { + lane_a[col] = load_lane_a(batch, kidx, lane_col_base, col); + } + + let k_offset = kidx / elements_in_value_b + lane_id; + if (k_offset < uniforms.K_of_b) { + for (var output_idx = 0u; output_idx < subgroup_tile_size; output_idx++) { + let b_global = b_global_base + output_idx; + if (b_global >= uniforms.N) { + continue; + } + +#if !single_scale_weights + let block_idx = (kidx + lane_id * elements_in_value_b) / uniforms.block_size; +#endif +#if single_scale_weights + let scale_b = scales_b.getByOffset(b_scale_offset); +#else + let scale_b = scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx + b_scale_offset); +#endif + let zero = mm_read_zero(b_global, block_idx, uniforms.N, 0u); + let b_value = b.getByOffset(b_global * uniforms.K_of_b + k_offset + b_base_offset); + + var sum = output_element_t(0); + var a_offset = 0u; + +#if component_b == 1 + let b_value_lower = vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(zero); + let b_value_upper = vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(zero); + let b0 = vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b; + let b1 = vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b; +#if component_a == 1 + sum += dot(vec4(lane_a[a_offset], lane_a[a_offset + 1], lane_a[a_offset + 2], lane_a[a_offset + 3]), b0) + + dot(vec4(lane_a[a_offset + 4], lane_a[a_offset + 5], lane_a[a_offset + 6], lane_a[a_offset + 7]), b1); +#elif component_a == 2 + sum += dot(vec4(lane_a[a_offset], lane_a[a_offset + 1]), b0) + + dot(vec4(lane_a[a_offset + 2], lane_a[a_offset + 3]), b1); +#elif component_a == 4 + sum += dot(lane_a[a_offset], b0) + dot(lane_a[a_offset + 1], b1); +#endif +#else + for (var i = 0u; i < component_b; i++) { + let b_value_lower = vec4(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4(zero); + let b_value_upper = vec4(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); + let b0 = vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b; + let b1 = vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b; +#if component_a == 1 + sum += dot(vec4(lane_a[a_offset], lane_a[a_offset + 1], lane_a[a_offset + 2], lane_a[a_offset + 3]), b0) + + dot(vec4(lane_a[a_offset + 4], lane_a[a_offset + 5], lane_a[a_offset + 6], lane_a[a_offset + 7]), b1); + a_offset += 8; +#elif component_a == 2 + sum += dot(vec4(lane_a[a_offset], lane_a[a_offset + 1]), b0) + + dot(vec4(lane_a[a_offset + 2], lane_a[a_offset + 3]), b1); + a_offset += 4; +#elif component_a == 4 + sum += dot(lane_a[a_offset], b0) + dot(lane_a[a_offset + 1], b1); + a_offset += 2; +#endif + } +#endif + + accum[output_idx] += sum; + } + } + } + + for (var output_idx = 0u; output_idx < subgroup_tile_size; output_idx++) { + var reduced = accum[output_idx]; + if (lane_id < 16u) { + reduced += subgroupShuffle(reduced, lane_id + 16u); + } + if (lane_id < 8u) { + reduced += subgroupShuffle(reduced, lane_id + 8u); + } + if (lane_id < 4u) { + reduced += subgroupShuffle(reduced, lane_id + 4u); + } + if (lane_id < 2u) { + reduced += subgroupShuffle(reduced, lane_id + 2u); + } + if (lane_id < 1u) { + reduced += subgroupShuffle(reduced, lane_id + 1u); + } + + if (lane_id == 0u) { + let b_global = b_global_base + output_idx; + if (b_global < uniforms.N) { + var output_value = reduced; +#if has_bias + output_value += bias[b_global + b_bias_offset]; +#endif + output.setByOffset(batch * uniforms.N + b_global, output_value); + } + } + } +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc new file mode 100644 index 0000000000000..13bb7079462b2 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc @@ -0,0 +1,330 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/quantization/matmul_nbits_silu.h" + +#include "contrib_ops/webgpu/quantization/matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" +#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +namespace { + +constexpr unsigned int kMinMForTileOptimization = 4; + +constexpr uint32_t kFusedDecodeFastPathBits = 4u; +constexpr uint32_t kFusedDecodeFastPathBlockSize = 32u; + +bool IsFusedDecodeFastPathEnabled() { + return true; +} + +Status WouldApplyGenericMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + int64_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool& would_apply_generic) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t N = onnxruntime::narrow(helper.N()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_op); + + const bool single_scale_weights = (block_size == K * N); + const uint32_t block_size_per_col = single_scale_weights ? K : block_size; + const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_a = GetMaxComponents(K); + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + +#if !defined(__wasm__) + int32_t subgroup_matrix_config_index = -1; + const bool would_apply_subgroup = + (M >= kMinMForTileOptimization) && + (context.AdapterInfo().vendor == std::string_view{"apple"} || + context.AdapterInfo().vendor == std::string_view{"intel"}) && + CanApplySubgroupMatrixMatMulNBits(context, + accuracy_level, + block_size, + batch_count, + N, + K, + static_cast(nbits), + y->DataType() == DataTypeImpl::GetType(), + subgroup_matrix_config_index); + if (would_apply_subgroup) { + would_apply_generic = false; + return Status::OK(); + } +#endif + + const bool would_apply_dp4a = + ((M >= kMinMForTileOptimization || + y->DataType() == DataTypeImpl::GetType() || + context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && + CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)); + if (would_apply_dp4a) { + would_apply_generic = false; + return Status::OK(); + } + + const bool would_apply_wide_tile = block_size == 32 && + components_a == 4 && + components_b == 4 && + nbits != 2 && + M >= kMinMForTileOptimization; + would_apply_generic = !would_apply_wide_tile; + return Status::OK(); +} + +class MatMulNBitsSiluMulDecodeProgram final : public Program { + public: + MatMulNBitsSiluMulDecodeProgram(uint32_t tile_size, + bool has_gate_bias, + bool has_up_bias, + bool single_scale_weights, + uint32_t tile_size_k_vec) + : Program{"MatMulNBitsSiluMulDecode"}, + tile_size_(tile_size), + has_gate_bias_(has_gate_bias), + has_up_bias_(has_up_bias), + single_scale_weights_(single_scale_weights), + tile_size_k_vec_(tile_size_k_vec) {} + + Status GenerateShaderCode(ShaderHelper& shader) const override { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias); + const auto& gate_b = shader.AddInput("gate_b"); + const auto& gate_scales_b = shader.AddInput("gate_scales_b"); + const auto& up_b = shader.AddInput("up_b"); + const auto& up_scales_b = shader.AddInput("up_scales_b"); + if (has_gate_bias_) { + shader.AddInput("gate_bias", ShaderUsage::UseUniform); + } + if (has_up_bias_) { + shader.AddInput("up_bias", ShaderUsage::UseUniform); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseElementTypeAlias); + + const uint32_t components_a = a.NumComponents(); + const uint32_t components_b = gate_b.NumComponents() / 4; + const uint32_t tile_size_k_vec = tile_size_k_vec_; + const uint32_t elements_in_value_b = components_b * 8u; + const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; + const uint32_t a_length_per_tile = tile_size_k / components_a; + const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec; + + return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_silu_mul.wgsl.template", + WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), + WGSL_TEMPLATE_PARAMETER(component_a, components_a), + WGSL_TEMPLATE_PARAMETER(component_b, components_b), + WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), + WGSL_TEMPLATE_PARAMETER(has_gate_bias, has_gate_bias_), + WGSL_TEMPLATE_PARAMETER(has_up_bias, has_up_bias_), + WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), + WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), + WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), + WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), + WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), + WGSL_TEMPLATE_VARIABLE(a, a), + WGSL_TEMPLATE_VARIABLE(gate_b, gate_b), + WGSL_TEMPLATE_VARIABLE(gate_scales_b, gate_scales_b), + WGSL_TEMPLATE_VARIABLE(output, output), + WGSL_TEMPLATE_VARIABLE(up_b, up_b), + WGSL_TEMPLATE_VARIABLE(up_scales_b, up_scales_b)); + } + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K_of_a", ProgramUniformVariableDataType::Uint32}, + {"K_of_b", ProgramUniformVariableDataType::Uint32}, + {"block_size", ProgramUniformVariableDataType::Uint32}, + {"blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"num_N_tile", ProgramUniformVariableDataType::Uint32}, + {"batch_count", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t tile_size_; + bool has_gate_bias_; + bool has_up_bias_; + bool single_scale_weights_; + uint32_t tile_size_k_vec_; +}; + +class MatMulNBitsSiluMulProgram final : public Program { + public: + MatMulNBitsSiluMulProgram() : Program{"MatMulNBitsSiluMul"} {} + + Status GenerateShaderCode(ShaderHelper& shader) const override { + const auto& gate = shader.AddInput("gate", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& up = shader.AddInput("up", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << "let gate_value = " << gate.GetByOffset("global_idx") << ";\n" + << "let up_value = " << up.GetByOffset("global_idx") << ";\n" + << "let one = output_value_t(1.0);\n" + << "let silu_value = gate_value * (one / (one + exp(-gate_value)));\n" + << output.SetByOffset("global_idx", "silu_value * up_value"); + + return Status::OK(); + } + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); +}; + +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBitsSiluMul, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", WebGpuSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBitsSiluMul); + +Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* a = context.Input(0); + const Tensor* gate_b = context.Input(1); + const Tensor* gate_scales = context.Input(2); + const Tensor* gate_bias = context.Input(3); + const Tensor* up_b = context.Input(4); + const Tensor* up_scales = context.Input(5); + const Tensor* up_bias = context.Input(6); + + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + const auto output_shape = helper.OutputShape(); + const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t N = onnxruntime::narrow(helper.N()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_); + + Tensor* y = context.Output(0, output_shape); + const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + bool gate_would_use_generic_matmul = false; + bool up_would_use_generic_matmul = false; + ORT_RETURN_IF_ERROR(WouldApplyGenericMatMulNBitsInCurrentDispatch(a, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + context, + y, + gate_would_use_generic_matmul)); + ORT_RETURN_IF_ERROR(WouldApplyGenericMatMulNBitsInCurrentDispatch(a, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + context, + y, + up_would_use_generic_matmul)); + + if (IsFusedDecodeFastPathEnabled() && M == 1 && bits_ == kFusedDecodeFastPathBits && + block_size == kFusedDecodeFastPathBlockSize && gate_would_use_generic_matmul && + up_would_use_generic_matmul) { + ORT_ENFORCE(bits_ == kFusedDecodeFastPathBits, + "MatMulNBitsSiluMulDecodeProgram is specialized for 4-bit weights only."); + ORT_ENFORCE(block_size == kFusedDecodeFastPathBlockSize, + "MatMulNBitsSiluMulDecodeProgram is specialized for block_size=32 only."); + + const bool has_gate_bias = gate_bias != nullptr; + const bool has_up_bias = up_bias != nullptr; + const bool single_scale_weights = (block_size == K * N); + const uint32_t block_size_per_col = single_scale_weights ? K : block_size; + const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; + const uint32_t blob_size = (block_size_per_col / 8) * onnxruntime::narrow(bits_); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_a = GetMaxComponents(K); + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + constexpr uint32_t kU32Components = 4; + const uint32_t components_b_with_u32 = components_b * kU32Components; + const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; + constexpr uint32_t workgroup_size = 128; + constexpr uint32_t tile_size = 8; + const uint32_t tile_size_k_vec = + (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; + const uint32_t num_N_tile = CeilDiv(N, tile_size); + + MatMulNBitsSiluMulDecodeProgram program{tile_size, + has_gate_bias, + has_up_bias, + single_scale_weights, + tile_size_k_vec}; + program.SetWorkgroupSize(workgroup_size); + program.SetDispatchGroupSize(num_N_tile, 1, batch_count); + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, + {gate_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {gate_scales, ProgramTensorMetadataDependency::TypeAndRank}, + {up_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {up_scales, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) + .AddUniformVariables({{N}, + {K}, + {K / components_a}, + {K_of_b}, + {block_size}, + {n_blocks_per_col}, + {num_N_tile}, + {batch_count}}) + .CacheHint(single_scale_weights, has_gate_bias, has_up_bias, tile_size_k_vec, "decode_4bit"); + if (has_gate_bias) { + program.AddInput({gate_bias, ProgramTensorMetadataDependency::None}); + } + if (has_up_bias) { + program.AddInput({up_bias, ProgramTensorMetadataDependency::None}); + } + + return context.RunProgram(program); + } + + Tensor gate_output = context.CreateGPUTensor(a->DataType(), output_shape); + Tensor up_output = context.CreateGPUTensor(a->DataType(), output_shape); + + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, gate_b, gate_scales, nullptr, gate_bias, K_, N_, block_size_, accuracy_level_, bits_, context, &gate_output)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, up_b, up_scales, nullptr, up_bias, K_, N_, block_size_, accuracy_level_, bits_, context, &up_output)); + + const uint32_t vec_size = (data_size + 3u) / 4u; + MatMulNBitsSiluMulProgram program; + program + .AddInputs({{&gate_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}, + {&up_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}}) + .AddOutput({y, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({vec_size}); + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.h new file mode 100644 index 0000000000000..476a76c72fa34 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class MatMulNBitsSiluMul final : public WebGpuKernel { + public: + explicit MatMulNBitsSiluMul(const OpKernelInfo& info) : WebGpuKernel(info) { + K_ = info.GetAttr("K"); + N_ = info.GetAttr("N"); + block_size_ = info.GetAttr("block_size"); + bits_ = info.GetAttr("bits"); + accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4); + ORT_ENFORCE(bits_ == 4 || bits_ == 8 || bits_ == 2, + "Only 4b/8b/2b quantization is supported for MatMulNBitsSiluMul op."); + } + + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t accuracy_level_; + int64_t bits_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template new file mode 100644 index 0000000000000..3976267b0a75b --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param a_length_per_tile +#param component_a +#param component_b +#param elements_in_value_b +#param single_scale_weights +#param sub_tile_count +#param tile_size_k_vec +#param tile_size_k +#param tile_size +#param has_gate_bias +#param has_up_bias + +#use .getByOffset .setByOffset + +var tile_A : array; +var gate_inter_results : array, tile_size>; +var up_inter_results : array, tile_size>; + +const default_zero_point = output_element_t(8); + +fn loadSHMA(batch: u32, kidx: u32, col: u32) +{ + let k_offset = kidx / component_a + col; + if (batch < uniforms.batch_count && k_offset < uniforms.K_of_a) { + tile_A[col] = a.getByOffset(batch * uniforms.K_of_a + k_offset); + } else { + tile_A[col] = input_a_value_t(0); + } +} + +$MAIN { + let batch = workgroup_idx / uniforms.num_N_tile; + let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; + + let idx = local_idx % tile_size_k_vec; + let idy = local_idx / tile_size_k_vec; + + if (local_idx < tile_size) { + for (var b = 0u; b < tile_size_k_vec; b++) { + gate_inter_results[local_idx][b] = output_element_t(0); + up_inter_results[local_idx][b] = output_element_t(0); + } + } + workgroupBarrier(); + +#if single_scale_weights + let gate_scale_b = gate_scales_b.getByOffset(0); + let up_scale_b = up_scales_b.getByOffset(0); + let block_idx = 0u; +#endif + + for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) + { + for (var id = local_idx; id < a_length_per_tile; id += workgroup_size_x) + { + loadSHMA(batch, kidx, id); + } + workgroupBarrier(); + + for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count) + { + let b_global = b_global_base + local_row_offset + idy; + let k_offset = kidx / elements_in_value_b + idx; + if (b_global < uniforms.N && k_offset < uniforms.K_of_b) + { +#if !single_scale_weights + let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; + let gate_scale_b = gate_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); + let up_scale_b = up_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); +#endif + let gate_b_value = gate_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + let up_b_value = up_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + + var gate_sum = output_element_t(0); + var up_sum = output_element_t(0); + var a_offset = idx * (8 / component_a) * component_b; +#if component_b == 1 + let gate_b_value_lower = vec4(unpack4xU8(gate_b_value & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b_value_upper = vec4(unpack4xU8((gate_b_value >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b0 = vec4(gate_b_value_lower[0], gate_b_value_upper[0], gate_b_value_lower[1], gate_b_value_upper[1]) * gate_scale_b; + let gate_b1 = vec4(gate_b_value_lower[2], gate_b_value_upper[2], gate_b_value_lower[3], gate_b_value_upper[3]) * gate_scale_b; + let up_b_value_lower = vec4(unpack4xU8(up_b_value & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b_value_upper = vec4(unpack4xU8((up_b_value >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b0 = vec4(up_b_value_lower[0], up_b_value_upper[0], up_b_value_lower[1], up_b_value_upper[1]) * up_scale_b; + let up_b1 = vec4(up_b_value_lower[2], up_b_value_upper[2], up_b_value_lower[3], up_b_value_upper[3]) * up_scale_b; +#if component_a == 1 + let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]); + let a1 = vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); +#elif component_a == 2 + let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1]); + let a1 = vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); +#elif component_a == 4 + let a0 = tile_A[a_offset]; + let a1 = tile_A[a_offset + 1]; + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); +#endif +#else + for (var i = 0u; i < component_b; i++) { + let gate_b_value_lower = vec4(unpack4xU8(gate_b_value[i] & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b_value_upper = vec4(unpack4xU8((gate_b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b0 = vec4(gate_b_value_lower[0], gate_b_value_upper[0], gate_b_value_lower[1], gate_b_value_upper[1]) * gate_scale_b; + let gate_b1 = vec4(gate_b_value_lower[2], gate_b_value_upper[2], gate_b_value_lower[3], gate_b_value_upper[3]) * gate_scale_b; + let up_b_value_lower = vec4(unpack4xU8(up_b_value[i] & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b_value_upper = vec4(unpack4xU8((up_b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b0 = vec4(up_b_value_lower[0], up_b_value_upper[0], up_b_value_lower[1], up_b_value_upper[1]) * up_scale_b; + let up_b1 = vec4(up_b_value_lower[2], up_b_value_upper[2], up_b_value_lower[3], up_b_value_upper[3]) * up_scale_b; +#if component_a == 1 + let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]); + let a1 = vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); + a_offset += 8; +#elif component_a == 2 + let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1]); + let a1 = vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); + a_offset += 4; +#elif component_a == 4 + let a0 = tile_A[a_offset]; + let a1 = tile_A[a_offset + 1]; + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); + a_offset += 2; +#endif + } +#endif + + gate_inter_results[local_row_offset + idy][idx] += gate_sum; + up_inter_results[local_row_offset + idy][idx] += up_sum; + } + } + workgroupBarrier(); + } + + if (batch >= uniforms.batch_count) { + return; + } + + if (local_idx < tile_size) { + var gate_output_value = output_element_t(0); + var up_output_value = output_element_t(0); + for (var b = 0u; b < tile_size_k_vec; b++) { + gate_output_value += gate_inter_results[local_idx][b]; + up_output_value += up_inter_results[local_idx][b]; + } + let b_global = b_global_base + local_idx; + let output_idx = batch * uniforms.N + b_global; + if (b_global < uniforms.N) { +#if has_gate_bias + gate_output_value += gate_bias[b_global]; +#endif +#if has_up_bias + up_output_value += up_bias[b_global]; +#endif + let one = output_element_t(1.0); + let silu_value = gate_output_value * (one / (one + exp(-gate_output_value))); + output.setByOffset(output_idx, silu_value * up_output_value); + } + } +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 357eebee714d5..9389be885fbdf 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -22,6 +22,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gr // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBitsSiluMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding); @@ -50,6 +51,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5ccba675b4ecf..869c76dc15a2a 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3616,6 +3616,67 @@ For example, for 4 bits, the first 4 bits are stored in the lower 4 bits of a by } }); + static const char* MatMulNBitsSiluMul_ver1_doc = R"DOC( +MatMulNBitsSiluMul fuses two MatMulNBits projections that share the same input and computes + + Y = SiLU(MatMulNBits(A, gate_weight) + gate_bias) * (MatMulNBits(A, up_weight) + up_bias) + +where SiLU(x) = x * sigmoid(x). + +This operator is intended for decoder MLP patterns such as Qwen-style gate and up projections, but it remains +semantically valid for both prefill and decode because the output shape is the standard MatMul result shape +derived from the runtime shape of A and the shared attributes K and N. +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBitsSiluMul) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulNBitsSiluMul_ver1_doc) + .Attr("K", "Input feature dimension shared by both quantized weight matrices.", AttributeProto::INT) + .Attr("N", "Output feature dimension shared by both quantized weight matrices.", AttributeProto::INT) + .Attr("bits", "Bit-width used to quantize both weight matrices (valid range: 2~8)", AttributeProto::INT, static_cast(4)) + .Attr("block_size", + "Size of each quantization block along the K dimension. Must be a power of two and >= 16.", + AttributeProto::INT) + .Attr("accuracy_level", + "The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.", + AttributeProto::INT, static_cast(0)) + .Input(0, "A", "The shared input tensor.", "T1") + .Input(1, "gate_B", "Packed uint8 tensor for the gate projection weights.", "T2") + .Input(2, "gate_scales", "Per-block scaling factors for the gate projection.", "T1") + .Input(3, "gate_bias", "Optional bias for the gate projection with shape [N].", "T1", OpSchema::Optional) + .Input(4, "up_B", "Packed uint8 tensor for the up projection weights.", "T2") + .Input(5, "up_scales", "Per-block scaling factors for the up projection.", "T1") + .Input(6, "up_bias", "Optional bias for the up projection with shape [N].", "T1", OpSchema::Optional) + .Output(0, "Y", "The fused SiLU-multiply output tensor.", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + const int64_t in_features = getAttribute(ctx, "K", -1); + const int64_t out_features = getAttribute(ctx, "N", -1); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true); + + for (size_t bias_input_index : {3U, 6U}) { + if (!ctx.hasInput(static_cast(bias_input_index))) { + continue; + } + + if (!hasInputShape(ctx, static_cast(bias_input_index))) { + fail_shape_inference("bias shape must be known"); + } + + const auto& bias_shape = getInputShape(ctx, static_cast(bias_input_index)); + if (bias_shape.dim_size() != 1 || + !bias_shape.dim(0).has_dim_value() || + bias_shape.dim(0).dim_value() != out_features) { + fail_shape_inference("bias shape must be [N] where N = ", out_features); + } + } + }); + static const char* MatMulBnb4_ver1_doc = R"DOC( MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 9ed1d5e9e84fa..642afcb3f0c29 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -55,6 +55,7 @@ #include "core/optimizer/layer_norm_fusion.h" #include "core/optimizer/matmul_activation_fusion.h" #include "core/optimizer/matmul_add_fusion.h" +#include "core/optimizer/matmul_nbits_silu_fusion.h" #include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" @@ -99,6 +100,17 @@ namespace onnxruntime::optimizer_utils { +namespace { + +bool IsMatMulNBitsSiluFusionEnabled(const SessionOptions& session_options) { + const auto config_value = session_options.config_options.GetConfigOrDefault( + kOrtSessionOptionsEnableMatMulNBitsSiluFusion, + "0"); + return config_value != "0"; +} + +} // namespace + static void FilterTransformers(InlinedVector>& transformers, const InlinedHashSet& transformers_to_disable) { if (transformers_to_disable.empty()) return; @@ -436,6 +448,9 @@ InlinedVector> GenerateTransformers( #endif transformers.emplace_back(std::make_unique(cpu_ep)); + if (IsMatMulNBitsSiluFusionEnabled(session_options)) { + transformers.emplace_back(std::make_unique(InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); + } #endif // !defined(DISABLE_CONTRIB_OPS) // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their diff --git a/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.cc new file mode 100644 index 0000000000000..f17aa28a3c79d --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.cc @@ -0,0 +1,244 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/matmul_nbits_silu_fusion.h" + +#include + +#include "core/graph/graph_utils.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +namespace { + +bool HasInput(const Node& node, size_t index) { + return index < node.InputDefs().size() && node.InputDefs()[index] != nullptr && !node.InputDefs()[index]->Name().empty(); +} + +const Node* GetInputNode(const Graph& graph, const Node& node, size_t input_index) { + const auto* edge = graph_utils::GetInputEdge(node, static_cast(input_index)); + return edge == nullptr ? nullptr : graph.GetNode(edge->GetNode().Index()); +} + +bool IsSupportedMul(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14}); +} + +bool IsSupportedSigmoid(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}); +} + +bool IsMatMulNBitsWithoutZeroPointOrGroupIdx(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain) && + !HasInput(node, 3) && !HasInput(node, 4); +} + +int64_t GetIntAttr(const Node& node, const char* name, int64_t default_value, bool required = false) { + const auto* attr = graph_utils::GetNodeAttribute(node, name); + if (attr == nullptr) { + ORT_ENFORCE(!required, "Missing required attribute ", name, " on node ", node.Name()); + return default_value; + } + + return attr->i(); +} + +bool HasSingleNonGraphConsumer(const Graph& graph, const Node& node) { + return !graph.NodeProducesGraphOutput(node) && optimizer_utils::CheckOutputEdges(graph, node, 1); +} + +bool IsFuseCandidate(const Graph& graph, + const Node& gate_matmul, + const Node& up_matmul, + const Node& sigmoid, + const Node& silu_mul, + const Node& final_mul) { + if (!IsMatMulNBitsWithoutZeroPointOrGroupIdx(gate_matmul) || !IsMatMulNBitsWithoutZeroPointOrGroupIdx(up_matmul) || + !IsSupportedSigmoid(sigmoid) || !IsSupportedMul(silu_mul) || !IsSupportedMul(final_mul)) { + return false; + } + + if (!HasSingleNonGraphConsumer(graph, up_matmul) || !HasSingleNonGraphConsumer(graph, sigmoid) || + !HasSingleNonGraphConsumer(graph, silu_mul)) { + return false; + } + + if (graph.NodeProducesGraphOutput(gate_matmul) || gate_matmul.GetOutputEdgesCount() != 2) { + return false; + } + + if (gate_matmul.InputDefs().empty() || up_matmul.InputDefs().empty() || + gate_matmul.InputDefs()[0] != up_matmul.InputDefs()[0]) { + return false; + } + + if (sigmoid.InputDefs()[0] != gate_matmul.OutputDefs()[0]) { + return false; + } + + const bool silu_mul_matches = + (silu_mul.InputDefs()[0] == gate_matmul.OutputDefs()[0] && silu_mul.InputDefs()[1] == sigmoid.OutputDefs()[0]) || + (silu_mul.InputDefs()[1] == gate_matmul.OutputDefs()[0] && silu_mul.InputDefs()[0] == sigmoid.OutputDefs()[0]); + if (!silu_mul_matches) { + return false; + } + + const bool final_mul_matches = + (final_mul.InputDefs()[0] == silu_mul.OutputDefs()[0] && final_mul.InputDefs()[1] == up_matmul.OutputDefs()[0]) || + (final_mul.InputDefs()[1] == silu_mul.OutputDefs()[0] && final_mul.InputDefs()[0] == up_matmul.OutputDefs()[0]); + if (!final_mul_matches) { + return false; + } + + const int64_t gate_k = GetIntAttr(gate_matmul, "K", -1, true); + const int64_t up_k = GetIntAttr(up_matmul, "K", -1, true); + const int64_t gate_n = GetIntAttr(gate_matmul, "N", -1, true); + const int64_t up_n = GetIntAttr(up_matmul, "N", -1, true); + const int64_t gate_bits = GetIntAttr(gate_matmul, "bits", 4); + const int64_t up_bits = GetIntAttr(up_matmul, "bits", 4); + const int64_t gate_block_size = GetIntAttr(gate_matmul, "block_size", -1, true); + const int64_t up_block_size = GetIntAttr(up_matmul, "block_size", -1, true); + const int64_t gate_accuracy_level = GetIntAttr(gate_matmul, "accuracy_level", 0); + const int64_t up_accuracy_level = GetIntAttr(up_matmul, "accuracy_level", 0); + + return gate_k == up_k && gate_n == up_n && gate_bits == up_bits && gate_block_size == up_block_size && + gate_accuracy_level == up_accuracy_level; +} + +} // namespace + +Status MatMulNBitsSiluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* node_ptr = graph.GetNode(node_index); + if (node_ptr == nullptr) { + continue; + } + + auto& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!IsSupportedMul(node)) { + continue; + } + + const auto& node_ep = node.GetExecutionProviderType(); + if (!node_ep.empty() && node_ep != kWebGpuExecutionProvider) { + continue; + } + + const Node* input0 = GetInputNode(graph, node, 0); + const Node* input1 = GetInputNode(graph, node, 1); + if (input0 == nullptr || input1 == nullptr) { + continue; + } + + const Node* silu_mul = nullptr; + const Node* up_matmul = nullptr; + if (IsSupportedMul(*input0) && IsMatMulNBitsWithoutZeroPointOrGroupIdx(*input1)) { + silu_mul = input0; + up_matmul = input1; + } else if (IsSupportedMul(*input1) && IsMatMulNBitsWithoutZeroPointOrGroupIdx(*input0)) { + silu_mul = input1; + up_matmul = input0; + } else { + continue; + } + + const Node* silu_input0 = GetInputNode(graph, *silu_mul, 0); + const Node* silu_input1 = GetInputNode(graph, *silu_mul, 1); + if (silu_input0 == nullptr || silu_input1 == nullptr) { + continue; + } + + const Node* gate_matmul = nullptr; + const Node* sigmoid = nullptr; + if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*silu_input0) && IsSupportedSigmoid(*silu_input1)) { + gate_matmul = silu_input0; + sigmoid = silu_input1; + } else if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*silu_input1) && IsSupportedSigmoid(*silu_input0)) { + gate_matmul = silu_input1; + sigmoid = silu_input0; + } else { + continue; + } + + if (!IsFuseCandidate(graph, *gate_matmul, *up_matmul, *sigmoid, *silu_mul, node)) { + continue; + } + + LOGS(logger, INFO) << "MatMulNBitsSiluFusion: matched candidate final_mul='" << node.Name() + << "' gate='" << gate_matmul->Name() << "' up='" << up_matmul->Name() + << "' sigmoid='" << sigmoid->Name() << "' silu_mul='" << silu_mul->Name() + << "' attrs={K=" << GetIntAttr(*gate_matmul, "K", -1, true) + << ", N=" << GetIntAttr(*gate_matmul, "N", -1, true) + << ", bits=" << GetIntAttr(*gate_matmul, "bits", 4) + << ", block_size=" << GetIntAttr(*gate_matmul, "block_size", -1, true) + << ", accuracy_level=" << GetIntAttr(*gate_matmul, "accuracy_level", 0) + << "}"; + + LOGS(logger, INFO) << "MatMulNBitsSiluFusion: EP state final_mul='" << node.GetExecutionProviderType() + << "' gate='" << gate_matmul->GetExecutionProviderType() + << "' up='" << up_matmul->GetExecutionProviderType() + << "' sigmoid='" << sigmoid->GetExecutionProviderType() + << "' silu_mul='" << silu_mul->GetExecutionProviderType() << "'"; + + if ((!gate_matmul->GetExecutionProviderType().empty() && gate_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) || + (!up_matmul->GetExecutionProviderType().empty() && up_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) || + (!sigmoid->GetExecutionProviderType().empty() && sigmoid->GetExecutionProviderType() != kWebGpuExecutionProvider) || + (!silu_mul->GetExecutionProviderType().empty() && silu_mul->GetExecutionProviderType() != kWebGpuExecutionProvider)) { + LOGS(logger, INFO) << "MatMulNBitsSiluFusion: skipping candidate due to non-WebGPU EP assignment."; + continue; + } + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", GetIntAttr(*gate_matmul, "K", -1, true)), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", GetIntAttr(*gate_matmul, "N", -1, true)), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", GetIntAttr(*gate_matmul, "bits", 4)), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", GetIntAttr(*gate_matmul, "block_size", -1, true)), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", GetIntAttr(*gate_matmul, "accuracy_level", 0)), attrs); + + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + + InlinedVector fused_inputs{ + const_cast(gate_matmul->InputDefs()[0]), + const_cast(gate_matmul->InputDefs()[1]), + const_cast(gate_matmul->InputDefs()[2]), + HasInput(*gate_matmul, 5) ? const_cast(gate_matmul->InputDefs()[5]) : &empty_arg, + const_cast(up_matmul->InputDefs()[1]), + const_cast(up_matmul->InputDefs()[2]), + HasInput(*up_matmul, 5) ? const_cast(up_matmul->InputDefs()[5]) : &empty_arg, + }; + + Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsSiluMul"), + "MatMulNBitsSiluMul", + "fused MatMulNBits gate/up projections with SiLU multiply", + fused_inputs, + {}, + &attrs, + kMSDomain); + fused_node.SetExecutionProviderType(kWebGpuExecutionProvider); + + LOGS(logger, INFO) << "MatMulNBitsSiluFusion: created fused node '" << fused_node.Name() + << "' from final_mul='" << node.Name() << "'"; + + graph_utils::FinalizeNodeFusion(graph, + {std::ref(const_cast(*gate_matmul)), + std::ref(const_cast(*up_matmul)), + std::ref(const_cast(*sigmoid)), + std::ref(const_cast(*silu_mul)), + std::ref(node)}, + fused_node); + + modified = true; + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.h b/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.h new file mode 100644 index 0000000000000..d2c84dc2a3983 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +class MatMulNBitsSiluFusion : public GraphTransformer { + public: + explicit MatMulNBitsSiluFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("MatMulNBitsSiluFusion", compatible_execution_providers) {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc index 9cfda7033e8d2..901fc97c5c517 100644 --- a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc +++ b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc @@ -54,6 +54,19 @@ struct DecodeBenchConfig { int64_t accuracy_level; }; +enum class MlpDecodeBenchmarkVariant { + kUnfused, + kFused, +}; + +struct MlpDecodeBenchConfig { + int64_t n; + int64_t k; + int64_t bits; + int64_t block_size; + int64_t accuracy_level; +}; + struct AdapterSelectionConfig { // adapter_type: Dawn adapter type to select, e.g. integrated or discrete GPU. // preferred_vendor_id/device_id: stable PCI identifiers used to locate the target GPU regardless of enumeration order. @@ -101,6 +114,15 @@ struct DecodeTrafficStats { double total_bytes; }; +struct MlpTrafficStats { + double input_bytes; + double packed_weight_bytes; + double scale_bytes; + double intermediate_bytes; + double output_bytes; + double total_bytes; +}; + constexpr double kRtx5060TiTheoreticalBandwidthBytesPerSecond = 448.0 * 1000.0 * 1000.0 * 1000.0; constexpr int kDecodeWarmupRuns = 25; @@ -271,6 +293,29 @@ DecodeTrafficStats CalculateDecodeTrafficStats(const DecodeBenchConfig& config) }; } + MlpTrafficStats CalculateMlpTrafficStats(const MlpDecodeBenchConfig& config, MlpDecodeBenchmarkVariant variant) { + const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; + const int64_t blob_size = (config.block_size * config.bits) / 8; + + const double input_reads = variant == MlpDecodeBenchmarkVariant::kFused ? 1.0 : 2.0; + const double intermediate_bytes = + variant == MlpDecodeBenchmarkVariant::kFused ? 0.0 : 4.0 * static_cast(config.n) * sizeof(Ort::Float16_t); + const double input_bytes = input_reads * static_cast(config.k) * sizeof(Ort::Float16_t); + const double packed_weight_bytes = + 2.0 * static_cast(config.n) * static_cast(k_blocks) * static_cast(blob_size); + const double scale_bytes = 2.0 * static_cast(config.n) * static_cast(k_blocks) * sizeof(Ort::Float16_t); + const double output_bytes = static_cast(config.n) * sizeof(Ort::Float16_t); + + return { + input_bytes, + packed_weight_bytes, + scale_bytes, + intermediate_bytes, + output_bytes, + input_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, + }; + } + AdapterSelectionConfig GetAdapterSelectionConfig() { if (GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kT1000) { return { @@ -482,6 +527,19 @@ void AddTensorInitializer(ONNX_NAMESPACE::GraphProto& graph, initializer->set_raw_data(values.data(), values.size() * sizeof(T)); } +void AddTensorValueInfo(ONNX_NAMESPACE::GraphProto& graph, + const std::string& name, + int32_t data_type, + const std::vector& dims) { + auto* value_info = graph.add_value_info(); + value_info->set_name(name); + value_info->mutable_type()->mutable_tensor_type()->set_elem_type(data_type); + auto* shape = value_info->mutable_type()->mutable_tensor_type()->mutable_shape(); + for (int64_t dim : dims) { + shape->add_dim()->set_dim_value(dim); + } +} + std::vector GetDecodeBenchConfigs() { // Each entry is {N, K, bits, block_size, accuracy_level} for a decode-style M=1 run. return { @@ -500,6 +558,15 @@ std::vector GetDecodeBenchConfigs() { }; } +std::vector GetMlpDecodeBenchConfigs() { + // Each entry is {N, K, bits, block_size, accuracy_level} for a decode-style M=1 MLP run. + return { + {6144, 2048, 4, 32, 4}, + {8192, 3072, 4, 32, 4}, + {11008, 4096, 4, 32, 4}, + }; +} + void AddMatMulNBitsNode(ONNX_NAMESPACE::GraphProto& graph, const std::string& node_name, const std::string& input_name, @@ -548,6 +615,58 @@ void AddMatMulNBitsNode(ONNX_NAMESPACE::GraphProto& graph, attr_accuracy->set_i(accuracy_level); } +void AddMatMulNBitsSiluMulNode(ONNX_NAMESPACE::GraphProto& graph, + const std::string& node_name, + const std::string& input_name, + const std::string& gate_weight_name, + const std::string& gate_scale_name, + const std::string& up_weight_name, + const std::string& up_scale_name, + const std::string& output_name, + int64_t k, + int64_t n, + int64_t bits, + int64_t block_size, + int64_t accuracy_level) { + auto* node = graph.add_node(); + node->set_name(node_name); + node->set_op_type("MatMulNBitsSiluMul"); + node->set_domain("com.microsoft"); + node->add_input(input_name); + node->add_input(gate_weight_name); + node->add_input(gate_scale_name); + node->add_input(""); + node->add_input(up_weight_name); + node->add_input(up_scale_name); + node->add_input(""); + node->add_output(output_name); + + auto* attr_k = node->add_attribute(); + attr_k->set_name("K"); + attr_k->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_k->set_i(k); + + auto* attr_n = node->add_attribute(); + attr_n->set_name("N"); + attr_n->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_n->set_i(n); + + auto* attr_bits = node->add_attribute(); + attr_bits->set_name("bits"); + attr_bits->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_bits->set_i(bits); + + auto* attr_block = node->add_attribute(); + attr_block->set_name("block_size"); + attr_block->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_block->set_i(block_size); + + auto* attr_accuracy = node->add_attribute(); + attr_accuracy->set_name("accuracy_level"); + attr_accuracy->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_accuracy->set_i(accuracy_level); +} + std::vector SerializeMatMulNBitsModel(const DecodeBenchConfig& config) { const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; const int64_t blob_size = (config.block_size * config.bits) / 8; @@ -601,10 +720,138 @@ std::vector SerializeMatMulNBitsModel(const DecodeBenchConfig& config) return std::vector(serialized.begin(), serialized.end()); } +std::string GetMlpVariantLabel(MlpDecodeBenchmarkVariant variant) { + return variant == MlpDecodeBenchmarkVariant::kFused ? "fused" : "unfused"; +} + +std::string GetMlpDecodeBenchmarkLabel(MlpDecodeBenchmarkVariant variant) { + std::ostringstream stream; + stream << "fp16_mlp_decode_" << GetMlpVariantLabel(variant) << '_' + << (IsDecodeBenchmarkPerfMode() ? "perf" : "correctness") << '_' + << (GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kRtx5060Ti ? "rtx" : "t") << '_' + << (IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off"); + return stream.str(); +} + +std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& config, + MlpDecodeBenchmarkVariant variant) { + const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; + const int64_t blob_size = (config.block_size * config.bits) / 8; + + ONNX_NAMESPACE::ModelProto model; + model.set_ir_version(10); + + auto* onnx_opset = model.add_opset_import(); + onnx_opset->set_domain(""); + onnx_opset->set_version(21); + auto* ms_opset = model.add_opset_import(); + ms_opset->set_domain("com.microsoft"); + ms_opset->set_version(1); + + auto* graph = model.mutable_graph(); + graph->set_name(variant == MlpDecodeBenchmarkVariant::kFused ? "WebGpuMatMulNBitsMlpDecodeFused" + : "WebGpuMatMulNBitsMlpDecodeUnfused"); + + auto* input = graph->add_input(); + input->set_name("A"); + input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); + + auto* output = graph->add_output(); + output->set_name("Y"); + output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.n); + + std::vector gate_b(static_cast(config.n * k_blocks * blob_size), uint8_t{0x11}); + std::vector up_b(static_cast(config.n * k_blocks * blob_size), uint8_t{0x77}); + std::vector gate_scales(static_cast(config.n * k_blocks), Ort::Float16_t(0.03125f)); + std::vector up_scales(static_cast(config.n * k_blocks), Ort::Float16_t(0.0625f)); + + AddTensorInitializer(*graph, "gate_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, + {config.n, k_blocks, blob_size}, gate_b); + AddTensorInitializer(*graph, "up_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, + {config.n, k_blocks, blob_size}, up_b); + AddTensorInitializer(*graph, "gate_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + {config.n, k_blocks}, gate_scales); + AddTensorInitializer(*graph, "up_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + {config.n, k_blocks}, up_scales); + + if (variant == MlpDecodeBenchmarkVariant::kFused) { + AddMatMulNBitsSiluMulNode(*graph, + "MatMulNBitsSiluMulDecode", + "A", + "gate_B", + "gate_scales", + "up_B", + "up_scales", + "Y", + config.k, + config.n, + config.bits, + config.block_size, + config.accuracy_level); + } else { + AddTensorValueInfo(*graph, "gate_mm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); + AddTensorValueInfo(*graph, "up_mm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); + AddTensorValueInfo(*graph, "gate_sigmoid", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); + AddTensorValueInfo(*graph, "gate_silu", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); + + AddMatMulNBitsNode(*graph, + "GateMatMulNBitsDecode", + "A", + "gate_B", + "gate_scales", + "gate_mm", + config.k, + config.n, + config.bits, + config.block_size, + config.accuracy_level); + AddMatMulNBitsNode(*graph, + "UpMatMulNBitsDecode", + "A", + "up_B", + "up_scales", + "up_mm", + config.k, + config.n, + config.bits, + config.block_size, + config.accuracy_level); + + auto* sigmoid = graph->add_node(); + sigmoid->set_name("GateSigmoid"); + sigmoid->set_op_type("Sigmoid"); + sigmoid->add_input("gate_mm"); + sigmoid->add_output("gate_sigmoid"); + + auto* silu_mul = graph->add_node(); + silu_mul->set_name("GateSiluMul"); + silu_mul->set_op_type("Mul"); + silu_mul->add_input("gate_mm"); + silu_mul->add_input("gate_sigmoid"); + silu_mul->add_output("gate_silu"); + + auto* output_mul = graph->add_node(); + output_mul->set_name("GateUpMul"); + output_mul->set_op_type("Mul"); + output_mul->add_input("gate_silu"); + output_mul->add_input("up_mm"); + output_mul->add_output("Y"); + } + + const auto serialized = model.SerializeAsString(); + return std::vector(serialized.begin(), serialized.end()); +} + Ort::Session CreateSessionFromModelData(const std::vector& model_data, - const std::unordered_map* provider_options) { + const std::unordered_map* provider_options, + GraphOptimizationLevel graph_optimization_level = GraphOptimizationLevel::ORT_ENABLE_ALL) { Ort::SessionOptions session_options; session_options.DisableMemPattern(); + session_options.SetGraphOptimizationLevel(graph_optimization_level); if (provider_options != nullptr) { session_options.AppendExecutionProvider("WebGPU", *provider_options); } @@ -670,6 +917,62 @@ void ValidateDecodeOutputs(const std::vector& model_data, << " at index " << max_abs_diff_index << std::endl; } +void ValidateMlpDecodeOutputs(const std::vector& unfused_model_data, + const std::vector& fused_model_data, + const std::unordered_map& provider_options, + const char* const* input_names, + const Ort::Value* input_tensor, + const char* const* output_names) { + Ort::Session unfused_session = CreateSessionFromModelData(unfused_model_data, + &provider_options, + GraphOptimizationLevel::ORT_DISABLE_ALL); + Ort::Session fused_session = CreateSessionFromModelData(fused_model_data, + &provider_options, + GraphOptimizationLevel::ORT_ENABLE_ALL); + + auto unfused_outputs = unfused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensor, 1, output_names, 1); + auto fused_outputs = fused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensor, 1, output_names, 1); + + if (unfused_outputs.size() != 1 || fused_outputs.size() != 1) { + throw std::runtime_error("Expected a single output from both unfused and fused MLP sessions."); + } + + const auto& unfused_output = unfused_outputs[0]; + const auto& fused_output = fused_outputs[0]; + const size_t element_count = unfused_output.GetTensorTypeAndShapeInfo().GetElementCount(); + if (element_count != fused_output.GetTensorTypeAndShapeInfo().GetElementCount()) { + throw std::runtime_error("Unfused and fused MLP output sizes do not match."); + } + + const auto* unfused_data = unfused_output.GetTensorData(); + const auto* fused_data = fused_output.GetTensorData(); + float max_abs_diff = 0.0f; + size_t max_abs_diff_index = 0; + for (size_t i = 0; i < element_count; ++i) { + const float unfused_value = unfused_data[i].ToFloat(); + const float fused_value = fused_data[i].ToFloat(); + const float abs_diff = std::abs(unfused_value - fused_value); + const float allowed_diff = kDecodeCorrectnessAbsTolerance + + kDecodeCorrectnessRelTolerance * std::max(std::abs(unfused_value), std::abs(fused_value)); + if (abs_diff > max_abs_diff) { + max_abs_diff = abs_diff; + max_abs_diff_index = i; + } + if (abs_diff > allowed_diff) { + std::ostringstream stream; + stream << "MLP decode correctness check failed at index " << i + << ": unfused=" << unfused_value + << ", fused=" << fused_value + << ", abs_diff=" << abs_diff + << ", allowed_diff=" << allowed_diff; + throw std::runtime_error(stream.str()); + } + } + + std::cout << "MLP decode correctness check passed. max_abs_diff=" << max_abs_diff + << " at index " << max_abs_diff_index << std::endl; +} + static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { try { const DecodeBenchConfig config{ @@ -752,12 +1055,114 @@ static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { } } +void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, MlpDecodeBenchmarkVariant variant) { + try { + const MlpDecodeBenchConfig config{ + state.range(0), + state.range(1), + state.range(2), + state.range(3), + state.range(4), + }; + + if (config.k % config.block_size != 0) { + state.SkipWithError("K must be divisible by block_size for this benchmark skeleton."); + return; + } + + const MlpTrafficStats traffic = CalculateMlpTrafficStats(config, variant); + std::vector model_data = SerializeMatMulNBitsMlpModel(config, variant); + const SelectedWebGpuContext& selected_context = GetSelectedWebGpuContext(); + const GraphOptimizationLevel optimization_level = + variant == MlpDecodeBenchmarkVariant::kUnfused ? GraphOptimizationLevel::ORT_DISABLE_ALL + : GraphOptimizationLevel::ORT_ENABLE_ALL; + Ort::Session session = CreateSessionFromModelData(model_data, + &selected_context.provider_options, + optimization_level); + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + std::vector input_shape{1, config.k}; + std::vector activation(static_cast(config.k)); + + std::mt19937 rng(123); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for (auto& value : activation) { + value = Ort::Float16_t(dist(rng)); + } + + const char* input_names[] = {"A"}; + const char* output_names[] = {"Y"}; + + auto input_tensor = Ort::Value::CreateTensor(memory_info, + activation.data(), + activation.size(), + input_shape.data(), + input_shape.size()); + + if (!IsDecodeBenchmarkPerfMode()) { + ValidateMlpDecodeOutputs(SerializeMatMulNBitsMlpModel(config, MlpDecodeBenchmarkVariant::kUnfused), + SerializeMatMulNBitsMlpModel(config, MlpDecodeBenchmarkVariant::kFused), + selected_context.provider_options, + input_names, + &input_tensor, + output_names); + } + + for (int i = 0; i < kDecodeWarmupRuns; ++i) { + auto warmup_outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); + benchmark::DoNotOptimize(warmup_outputs); + } + + double total_kernel_seconds = 0.0; + for (auto _ : state) { + const auto kernel_start = std::chrono::steady_clock::now(); + auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); + const auto kernel_end = std::chrono::steady_clock::now(); + total_kernel_seconds += std::chrono::duration(kernel_end - kernel_start).count(); + benchmark::DoNotOptimize(outputs); + } + + const double total_flops = 4.0 * static_cast(config.n) * static_cast(config.k); + const double achieved_bandwidth_bytes_per_second = + total_kernel_seconds > 0.0 + ? traffic.total_bytes * static_cast(state.iterations()) / total_kernel_seconds + : 0.0; + + state.SetLabel(GetMlpDecodeBenchmarkLabel(variant)); + state.counters["TFLOPS"] = benchmark::Counter( + total_flops, + benchmark::Counter::kIsIterationInvariantRate); + state.counters["ApproxMemBW_GBps"] = benchmark::Counter(achieved_bandwidth_bytes_per_second / 1.0e9); + state.counters["ApproxTraffic_MB"] = benchmark::Counter(traffic.total_bytes / 1.0e6); + state.counters["Input_MB"] = benchmark::Counter(traffic.input_bytes / 1.0e6); + state.counters["PackedW_MB"] = benchmark::Counter(traffic.packed_weight_bytes / 1.0e6); + state.counters["Scales_MB"] = benchmark::Counter(traffic.scale_bytes / 1.0e6); + state.counters["Intermediate_MB"] = benchmark::Counter(traffic.intermediate_bytes / 1.0e6); + state.counters["Output_MB"] = benchmark::Counter(traffic.output_bytes / 1.0e6); + } catch (const std::exception& ex) { + state.SkipWithError(ex.what()); + } +} + +static void BM_WebGpuMatMulNBitsMlpDecodeUnfused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused); +} + +static void BM_WebGpuMatMulNBitsMlpDecodeFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused); +} + void ApplyWebGpuMatMulNBitsDecodeArgs(benchmark::internal::Benchmark* benchmark) { for (const auto& config : GetDecodeBenchConfigs()) { benchmark->Args({config.n, config.k, config.bits, config.block_size, config.accuracy_level}); } } +void ApplyWebGpuMatMulNBitsMlpDecodeArgs(benchmark::internal::Benchmark* benchmark) { + for (const auto& config : GetMlpDecodeBenchConfigs()) { + benchmark->Args({config.n, config.k, config.bits, config.block_size, config.accuracy_level}); + } +} + BENCHMARK(BM_WebGpuMatMulNBitsDecode) ->Apply(ApplyWebGpuMatMulNBitsDecodeArgs) ->Repetitions(5) @@ -765,4 +1170,18 @@ BENCHMARK(BM_WebGpuMatMulNBitsDecode) ->UseRealTime() ->Unit(benchmark::TimeUnit::kMicrosecond); +BENCHMARK(BM_WebGpuMatMulNBitsMlpDecodeUnfused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) + ->Repetitions(5) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsMlpDecodeFused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) + ->Repetitions(5) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + } // namespace diff --git a/onnxruntime/test/optimizer/matmul_nbits_silu_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_silu_fusion_test.cc new file mode 100644 index 0000000000000..6d48344e27faf --- /dev/null +++ b/onnxruntime/test/optimizer/matmul_nbits_silu_fusion_test.cc @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/graph_transformer_mgr.h" +#include "core/optimizer/matmul_nbits_silu_fusion.h" +#include "core/optimizer/utils.h" + +#include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#include "test/optimizer/graph_transform_test_fixture.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +namespace { + +void SetWebGpuProvider(Node& node) { + node.SetExecutionProviderType(kWebGpuExecutionProvider); +} + +NodeAttributes MakeMatMulNBitsAttrs(int64_t k, int64_t n, int64_t block_size, int64_t bits, int64_t accuracy_level) { + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", k), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", n), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), attrs); + return attrs; +} + +Status CheckMatMulNBitsSiluFusedGraph(const Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + if (OpCount(op_to_count, "com.microsoft.MatMulNBitsSiluMul") != 1 || + OpCount(op_to_count, "com.microsoft.MatMulNBits") != 0 || + OpCount(op_to_count, "Sigmoid") != 0 || + OpCount(op_to_count, "Mul") != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected operator counts after MatMulNBitsSiluFusion."); + } + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBitsSiluMul") { + ORT_RETURN_IF_NOT(node.Domain() == kMSDomain, "Fused node must be in com.microsoft domain."); + ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kWebGpuExecutionProvider, + "Fused node must be assigned to WebGPU EP."); + } + } + + return Status::OK(); +} + +void BuildMatMulNBitsSiluWebGpuPattern(ModelTestBuilder& builder) { + constexpr int64_t k = 16; + constexpr int64_t n = 8; + constexpr int64_t block_size = 16; + constexpr int64_t bits = 4; + constexpr int64_t accuracy_level = 4; + constexpr int64_t blob_size = block_size * bits / 8; + + NodeArg* input = builder.MakeInput( + std::vector{1, k}, + std::vector{ + MLFloat16(-1.0f), MLFloat16(-0.875f), MLFloat16(-0.75f), MLFloat16(-0.625f), + MLFloat16(-0.5f), MLFloat16(-0.375f), MLFloat16(-0.25f), MLFloat16(-0.125f), + MLFloat16(0.125f), MLFloat16(0.25f), MLFloat16(0.375f), MLFloat16(0.5f), + MLFloat16(0.625f), MLFloat16(0.75f), MLFloat16(0.875f), MLFloat16(1.0f)}); + NodeArg* optional_tensor = builder.MakeOptionalTensor(); + + NodeArg* gate_weight = builder.MakeInitializer({n, 1, blob_size}, uint8_t{0}, uint8_t{15}); + NodeArg* gate_scale = builder.MakeInitializer({n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* gate_bias = builder.MakeInitializer({n}, MLFloat16(0.0f), MLFloat16(0.0f)); + NodeArg* up_weight = builder.MakeInitializer({n, 1, blob_size}, uint8_t{0}, uint8_t{15}); + NodeArg* up_scale = builder.MakeInitializer({n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* up_bias = builder.MakeInitializer({n}, MLFloat16(0.0f), MLFloat16(0.0f)); + + NodeArg* gate_out = builder.MakeIntermediate(std::vector{1, n}); + NodeArg* up_out = builder.MakeIntermediate(std::vector{1, n}); + NodeArg* sigmoid_out = builder.MakeIntermediate(std::vector{1, n}); + NodeArg* silu_out = builder.MakeIntermediate(std::vector{1, n}); + NodeArg* output = builder.MakeOutput(std::vector{1, n}); + + NodeAttributes matmul_attrs = MakeMatMulNBitsAttrs(k, n, block_size, bits, accuracy_level); + Node& gate_matmul = builder.AddNode("MatMulNBits", {input, gate_weight, gate_scale, optional_tensor, optional_tensor, gate_bias}, {gate_out}, kMSDomain, &matmul_attrs); + Node& up_matmul = builder.AddNode("MatMulNBits", {input, up_weight, up_scale, optional_tensor, optional_tensor, up_bias}, {up_out}, kMSDomain, &matmul_attrs); + Node& sigmoid = builder.AddNode("Sigmoid", {gate_out}, {sigmoid_out}); + Node& silu_mul = builder.AddNode("Mul", {gate_out, sigmoid_out}, {silu_out}); + Node& final_mul = builder.AddNode("Mul", {silu_out, up_out}, {output}); + + SetWebGpuProvider(gate_matmul); + SetWebGpuProvider(up_matmul); + SetWebGpuProvider(sigmoid); + SetWebGpuProvider(silu_mul); + SetWebGpuProvider(final_mul); +} + +} // namespace + +TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionFusesWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsSiluWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsSiluFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionMatchesUnfusedWebGpuResults) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsSiluFusedGraph(session.GetGraph())); + }; + + TransformerTester( + BuildMatMulNBitsSiluWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + {}, + {}, + std::move(webgpu_ep)); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) + +} // namespace test +} // namespace onnxruntime From 318b26be457cfe79089061297eae1191cc0025ed Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 20 Apr 2026 14:27:24 -0700 Subject: [PATCH 05/26] Stage --- .../dp4a_matmul_silu_mul.wgsl.template | 110 ++++++ .../webgpu/quantization/matmul_nbits.cc | 85 ----- .../webgpu/quantization/matmul_nbits.h | 34 -- .../matmul_nbits_m1.wgsl.template | 161 --------- .../webgpu/quantization/matmul_nbits_silu.cc | 326 ++++++++++++++---- .../core/optimizer/graph_transformer_utils.cc | 19 +- .../webgpu_matmul_nbits_decode.cc | 47 ++- .../optimizer/graph_transform_utils_test.cc | 43 +++ 8 files changed, 456 insertions(+), 369 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_silu_mul.wgsl.template delete mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_m1.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_silu_mul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_silu_mul.wgsl.template new file mode 100644 index 0000000000000..77dc522130f82 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_silu_mul.wgsl.template @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param tile_size +#param tile_size_k_vec +#param single_scale_weights +#param has_gate_bias +#param has_up_bias + +#use .getByOffset .setByOffset + +#include "quantization/dp4a_matmul_common.wgsl.template" + +const double_tile_size_k_vec = 2 * tile_size_k_vec; +const scale_a_size_in_tile_a = double_tile_size_k_vec / 8; + +var gate_inter_results: array, tile_size>; +var up_inter_results: array, tile_size>; +var tile_A: array, double_tile_size_k_vec>; +var scale_A: array; + +fn loadSHMA(batch: u32, kidx_v: u32, col: u32) { + let k_offset = kidx_v + col; + if (k_offset >= uniforms.K16) { + return; + } + + tile_A[col] = a.getByOffset(batch * uniforms.K16 + k_offset); + if (col < scale_a_size_in_tile_a) { + scale_A[col] = scales_a.getByOffset(batch * (uniforms.K / 128) + kidx_v / 8 + col); + } +} + +$MAIN { + let batch = workgroup_id.z; + if (batch >= uniforms.batch_count) { + return; + } + + let b_global_base = workgroup_id.x * tile_size; + let local_col = local_idx % tile_size_k_vec; + let local_row = local_idx / tile_size_k_vec; + + if (local_idx < tile_size) { + for (var lane = 0u; lane < tile_size_k_vec; lane++) { + gate_inter_results[local_idx][lane] = output_element_t(0); + up_inter_results[local_idx][lane] = output_element_t(0); + } + } + workgroupBarrier(); + +#if single_scale_weights + let gate_scale_b = gate_scales_b.getByOffset(0); + let up_scale_b = up_scales_b.getByOffset(0); +#endif + + for (var kidx_v: u32 = 0u; kidx_v < uniforms.K32; kidx_v += tile_size_k_vec) { + if (local_idx < double_tile_size_k_vec) { + loadSHMA(batch, kidx_v * 2u, local_idx); + } + workgroupBarrier(); + + let own_a0 = tile_A[local_col * 2u]; + let own_a1 = tile_A[local_col * 2u + 1u]; + let own_scale_a = scale_A[local_col / 4u]; + let k_offset = kidx_v + local_col; + let block_idx = k_offset * 32u / uniforms.block_size; + + let b_global = b_global_base + local_row; + if (b_global < uniforms.N && k_offset < uniforms.K32) { +#if !single_scale_weights + let gate_scale_b = gate_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); + let up_scale_b = up_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); +#endif + let gate_b_value = gate_b.getByOffset(b_global * uniforms.K32 + k_offset); + let up_b_value = up_b.getByOffset(b_global * uniforms.K32 + k_offset); + let gate_b0 = DequantizedFrom4BitsTo8Bits(gate_b_value.xy, default_zero_point); + let gate_b1 = DequantizedFrom4BitsTo8Bits(gate_b_value.zw, default_zero_point); + let up_b0 = DequantizedFrom4BitsTo8Bits(up_b_value.xy, default_zero_point); + let up_b1 = DequantizedFrom4BitsTo8Bits(up_b_value.zw, default_zero_point); + let gate_scale = own_scale_a * gate_scale_b; + let up_scale = own_scale_a * up_scale_b; + gate_inter_results[local_row][local_col] += SDP8AI(own_a0, gate_b0, own_a1, gate_b1, gate_scale); + up_inter_results[local_row][local_col] += SDP8AI(own_a0, up_b0, own_a1, up_b1, up_scale); + } + workgroupBarrier(); + } + + if (local_idx < tile_size) { + var gate_output_value = output_element_t(0); + var up_output_value = output_element_t(0); + for (var lane = 0u; lane < tile_size_k_vec; lane++) { + gate_output_value += gate_inter_results[local_idx][lane]; + up_output_value += up_inter_results[local_idx][lane]; + } + + let b_global = b_global_base + local_idx; + if (b_global < uniforms.N) { +#if has_gate_bias + gate_output_value += gate_bias[b_global]; +#endif +#if has_up_bias + up_output_value += up_bias[b_global]; +#endif + let one = output_element_t(1.0); + let silu_value = gate_output_value * (one / (one + exp(-gate_output_value))); + output.setByOffset(batch * uniforms.N + b_global, silu_value * up_output_value); + } + } +} diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index afdc49d67765b..0db99c816dc29 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -68,43 +68,6 @@ Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) cons WGSL_TEMPLATE_VARIABLE(scales, scales)); } -Status MatMulNBitsM1Program::GenerateShaderCode(ShaderHelper& shader) const { - const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias); - const auto& b = shader.AddInput("input_b"); - const auto& scales_b = shader.AddInput("scales_b"); - if (has_bias_) { - shader.AddInput("bias", ShaderUsage::UseUniform); - } - const auto& output = shader.AddOutput("output", ShaderUsage::UseElementTypeAlias); - - const uint32_t components_a = a.NumComponents(); - const uint32_t components_b = b.NumComponents() / 4; - const uint32_t elements_in_value_b = components_b * 8u; - const uint32_t a_length_per_lane = elements_in_value_b / components_a; - const uint32_t tile_size_k = tile_size_k_vec_ * elements_in_value_b; - ORT_ENFORCE(tile_size_ % 4u == 0u, "tile_size must be divisible by 4 for MatMulNBitsM1Program."); - - return WGSL_TEMPLATE_APPLY(shader, - "quantization/matmul_nbits_m1.wgsl.template", - WGSL_TEMPLATE_PARAMETER(a_length_per_lane, a_length_per_lane), - WGSL_TEMPLATE_PARAMETER(component_a, components_a), - WGSL_TEMPLATE_PARAMETER(component_b, components_b), - WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), - WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_), - WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_), - WGSL_TEMPLATE_PARAMETER(has_zero_points, false), - WGSL_TEMPLATE_PARAMETER(n_bits, 4), - WGSL_TEMPLATE_PARAMETER(output_type_i32, false), - WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), - WGSL_TEMPLATE_PARAMETER(subgroup_tile_size, tile_size_ / 4u), - WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), - WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec_), - WGSL_TEMPLATE_VARIABLE(a, a), - WGSL_TEMPLATE_VARIABLE(b, b), - WGSL_TEMPLATE_VARIABLE(output, output), - WGSL_TEMPLATE_VARIABLE(scales_b, scales_b)); -} - // Apply similar idea with DP4AMatMulNBitsSmallMProgram algorithm. Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias); @@ -339,54 +302,6 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, return context.RunProgram(program); } - // Disabled for now so decode uses the existing generic fallback path that is - // already deployed in production. We can reintroduce M1-specific ideas when - // tuning the fused implementation. - // const bool use_m1_subgroup_program = M == 1 && - // batch_count >= 1 && - // static_cast(nbits) == 4u && - // !has_zero_points && - // !has_weight_idx_indirect && - // context.AdapterInfo().vendor == std::string_view{"nvidia"} && - // context.HasFeature(wgpu::FeatureName::Subgroups); - // - // if (use_m1_subgroup_program) { - // constexpr uint32_t workgroup_size = 128; - // constexpr uint32_t tile_size = 8; - // constexpr uint32_t tile_size_k_vec = 32; - // constexpr uint32_t kU32Components = 4; - // const uint32_t components_b_with_u32 = components_b * kU32Components; - // const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; - // - // MatMulNBitsM1Program program{tile_size, - // has_bias, - // has_weight_idx, - // single_scale_weights, - // tile_size_k_vec}; - // program.SetWorkgroupSize(workgroup_size); - // const uint32_t num_N_tile = CeilDiv(N, tile_size); - // program.SetDispatchGroupSize(num_N_tile, 1, batch_count); - // program - // .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, - // {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, - // {scales, ProgramTensorMetadataDependency::TypeAndRank}}) - // .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) - // .AddUniformVariables({{N}, - // {K}, - // {K / components_a}, - // {K_of_b}, - // {block_size}, - // {n_blocks_per_col}, - // {num_N_tile}, - // {batch_count}, - // {weight_index}}) - // .CacheHint(has_bias, has_weight_idx, single_scale_weights, tile_size_k_vec); - // if (has_bias) { - // program.AddInput({bias, ProgramTensorMetadataDependency::None}); - // } - // return context.RunProgram(program); - // } - // Use tile_size_k_vec=32 by default for better K-dimension parallelism. // Intel devices use 16 as they have different subgroup/cache characteristics. const uint32_t tile_size_k_vec = diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index a009a7f757b55..295a0fb90dd2a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -39,40 +39,6 @@ class MatMulNBitsWideTileProgram final : public Program { - public: - MatMulNBitsM1Program(uint32_t tile_size, - bool has_bias, - bool has_weight_idx, - bool single_scale_weights, - uint32_t tile_size_k_vec = 32) - : Program{"MatMulNBitsM1"}, - tile_size_(tile_size), - has_bias_(has_bias), - has_weight_idx_{has_weight_idx}, - single_scale_weights_(single_scale_weights), - tile_size_k_vec_(tile_size_k_vec) {} - - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"N", ProgramUniformVariableDataType::Uint32}, - {"K", ProgramUniformVariableDataType::Uint32}, - {"K_of_a", ProgramUniformVariableDataType::Uint32}, - {"K_of_b", ProgramUniformVariableDataType::Uint32}, - {"block_size", ProgramUniformVariableDataType::Uint32}, - {"blocks_per_col", ProgramUniformVariableDataType::Uint32}, - {"num_N_tile", ProgramUniformVariableDataType::Uint32}, - {"batch_count", ProgramUniformVariableDataType::Uint32}, - {"weight_idx", ProgramUniformVariableDataType::Uint32}); - - private: - uint32_t tile_size_; - bool has_bias_; - bool has_weight_idx_; - bool single_scale_weights_; - uint32_t tile_size_k_vec_; -}; - class MatMulNBitsProgram final : public Program { public: MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect, bool single_scale_weights, uint32_t tile_size_k_vec = 16) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_m1.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_m1.wgsl.template deleted file mode 100644 index 149c89ae8017a..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_m1.wgsl.template +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#param a_length_per_tile -#param a_length_per_lane -#param component_a -#param component_b -#param elements_in_value_b -#param has_bias -#param has_weight_idx -#param single_scale_weights -#param subgroup_tile_size -#param tile_size_k -#param tile_size_k_vec - -#use .getByOffset .setByOffset - -#include "quantization/matmul_nbits_zero_pt.wgsl.template" - -fn load_lane_a(batch: u32, kidx: u32, lane_col_base: u32, col: u32) -> input_a_value_t { - let k_offset = kidx / component_a + lane_col_base + col; - if (batch < uniforms.batch_count && k_offset < uniforms.K_of_a) { - return a.getByOffset(batch * uniforms.K_of_a + k_offset); - } - - return input_a_value_t(0); -} - -$MAIN { - let batch = workgroup_id.z; - let subgroup_id = local_id.x / tile_size_k_vec; - let lane_id = local_id.x % tile_size_k_vec; - let b_global_base = workgroup_id.x * subgroup_tile_size * 4u + subgroup_id * subgroup_tile_size; - -#if has_weight_idx - let actual_weight_idx = uniforms.weight_idx; - let b_base_offset = actual_weight_idx * uniforms.K_of_b * uniforms.N; -#if single_scale_weights - let b_scale_offset = actual_weight_idx; -#else - let b_scale_offset = actual_weight_idx * uniforms.N * uniforms.blocks_per_col; -#endif -#else - const b_base_offset : u32 = 0u; - const b_scale_offset : u32 = 0u; - const actual_weight_idx : u32 = 0u; -#endif - -#if has_bias - let b_bias_offset = actual_weight_idx * uniforms.N; -#endif - - var accum : array; - for (var output_idx = 0u; output_idx < subgroup_tile_size; output_idx++) { - accum[output_idx] = output_element_t(0); - } - -#if single_scale_weights - let block_idx = 0u; -#endif - - let lane_col_base = lane_id * a_length_per_lane; - - for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) { - var lane_a : array; - for (var col = 0u; col < a_length_per_lane; col++) { - lane_a[col] = load_lane_a(batch, kidx, lane_col_base, col); - } - - let k_offset = kidx / elements_in_value_b + lane_id; - if (k_offset < uniforms.K_of_b) { - for (var output_idx = 0u; output_idx < subgroup_tile_size; output_idx++) { - let b_global = b_global_base + output_idx; - if (b_global >= uniforms.N) { - continue; - } - -#if !single_scale_weights - let block_idx = (kidx + lane_id * elements_in_value_b) / uniforms.block_size; -#endif -#if single_scale_weights - let scale_b = scales_b.getByOffset(b_scale_offset); -#else - let scale_b = scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx + b_scale_offset); -#endif - let zero = mm_read_zero(b_global, block_idx, uniforms.N, 0u); - let b_value = b.getByOffset(b_global * uniforms.K_of_b + k_offset + b_base_offset); - - var sum = output_element_t(0); - var a_offset = 0u; - -#if component_b == 1 - let b_value_lower = vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(zero); - let b_value_upper = vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(zero); - let b0 = vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b; - let b1 = vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b; -#if component_a == 1 - sum += dot(vec4(lane_a[a_offset], lane_a[a_offset + 1], lane_a[a_offset + 2], lane_a[a_offset + 3]), b0) + - dot(vec4(lane_a[a_offset + 4], lane_a[a_offset + 5], lane_a[a_offset + 6], lane_a[a_offset + 7]), b1); -#elif component_a == 2 - sum += dot(vec4(lane_a[a_offset], lane_a[a_offset + 1]), b0) + - dot(vec4(lane_a[a_offset + 2], lane_a[a_offset + 3]), b1); -#elif component_a == 4 - sum += dot(lane_a[a_offset], b0) + dot(lane_a[a_offset + 1], b1); -#endif -#else - for (var i = 0u; i < component_b; i++) { - let b_value_lower = vec4(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4(zero); - let b_value_upper = vec4(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); - let b0 = vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b; - let b1 = vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b; -#if component_a == 1 - sum += dot(vec4(lane_a[a_offset], lane_a[a_offset + 1], lane_a[a_offset + 2], lane_a[a_offset + 3]), b0) + - dot(vec4(lane_a[a_offset + 4], lane_a[a_offset + 5], lane_a[a_offset + 6], lane_a[a_offset + 7]), b1); - a_offset += 8; -#elif component_a == 2 - sum += dot(vec4(lane_a[a_offset], lane_a[a_offset + 1]), b0) + - dot(vec4(lane_a[a_offset + 2], lane_a[a_offset + 3]), b1); - a_offset += 4; -#elif component_a == 4 - sum += dot(lane_a[a_offset], b0) + dot(lane_a[a_offset + 1], b1); - a_offset += 2; -#endif - } -#endif - - accum[output_idx] += sum; - } - } - } - - for (var output_idx = 0u; output_idx < subgroup_tile_size; output_idx++) { - var reduced = accum[output_idx]; - if (lane_id < 16u) { - reduced += subgroupShuffle(reduced, lane_id + 16u); - } - if (lane_id < 8u) { - reduced += subgroupShuffle(reduced, lane_id + 8u); - } - if (lane_id < 4u) { - reduced += subgroupShuffle(reduced, lane_id + 4u); - } - if (lane_id < 2u) { - reduced += subgroupShuffle(reduced, lane_id + 2u); - } - if (lane_id < 1u) { - reduced += subgroupShuffle(reduced, lane_id + 1u); - } - - if (lane_id == 0u) { - let b_global = b_global_base + output_idx; - if (b_global < uniforms.N) { - var output_value = reduced; -#if has_bias - output_value += bias[b_global + b_bias_offset]; -#endif - output.setByOffset(batch * uniforms.N + b_global, output_value); - } - } - } -} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc index 13bb7079462b2..5f8eb9b92f836 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc @@ -24,22 +24,32 @@ constexpr unsigned int kMinMForTileOptimization = 4; constexpr uint32_t kFusedDecodeFastPathBits = 4u; constexpr uint32_t kFusedDecodeFastPathBlockSize = 32u; -bool IsFusedDecodeFastPathEnabled() { - return true; +bool CanApplyDP4AFusedDecodePath(const Tensor* y, + onnxruntime::webgpu::ComputeContext& context, + uint64_t accuracy_level, + uint32_t block_size, + uint32_t N, + uint32_t K) { + if (y->DataType() == DataTypeImpl::GetType()) { + return false; + } + + return CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, 4); } -Status WouldApplyGenericMatMulNBitsInCurrentDispatch(const Tensor* a, - int64_t K_op, - int64_t N_op, - int64_t block_size_op, - int64_t accuracy_level, - int64_t nbits, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y, - bool& would_apply_generic) { +bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + int64_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { TensorShape b_shape({N_op, K_op}); MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); const uint32_t M = onnxruntime::narrow(helper.M()); @@ -47,51 +57,51 @@ Status WouldApplyGenericMatMulNBitsInCurrentDispatch(const Tensor* a, const uint32_t K = onnxruntime::narrow(helper.K()); const uint32_t block_size = onnxruntime::narrow(block_size_op); - const bool single_scale_weights = (block_size == K * N); - const uint32_t block_size_per_col = single_scale_weights ? K : block_size; - const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits); - const uint32_t blob_size_in_words = blob_size / 4; - const uint32_t components_a = GetMaxComponents(K); - const uint32_t components_b = GetMaxComponents(blob_size_in_words); - #if !defined(__wasm__) int32_t subgroup_matrix_config_index = -1; - const bool would_apply_subgroup = - (M >= kMinMForTileOptimization) && - (context.AdapterInfo().vendor == std::string_view{"apple"} || - context.AdapterInfo().vendor == std::string_view{"intel"}) && - CanApplySubgroupMatrixMatMulNBits(context, - accuracy_level, - block_size, - batch_count, - N, - K, - static_cast(nbits), - y->DataType() == DataTypeImpl::GetType(), - subgroup_matrix_config_index); - if (would_apply_subgroup) { - would_apply_generic = false; - return Status::OK(); - } + return (M >= kMinMForTileOptimization) && + (context.AdapterInfo().vendor == std::string_view{"apple"} || + context.AdapterInfo().vendor == std::string_view{"intel"}) && + CanApplySubgroupMatrixMatMulNBits(context, + accuracy_level, + block_size, + batch_count, + N, + K, + static_cast(nbits), + y->DataType() == DataTypeImpl::GetType(), + subgroup_matrix_config_index); #endif - const bool would_apply_dp4a = - ((M >= kMinMForTileOptimization || - y->DataType() == DataTypeImpl::GetType() || - context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && - CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)); - if (would_apply_dp4a) { - would_apply_generic = false; - return Status::OK(); + return false; +} + +bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t nbits) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; } - const bool would_apply_wide_tile = block_size == 32 && - components_a == 4 && - components_b == 4 && - nbits != 2 && - M >= kMinMForTileOptimization; - would_apply_generic = !would_apply_wide_tile; - return Status::OK(); + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_op); + + const uint32_t components_a = GetMaxComponents(K); + const uint32_t block_size_per_col = block_size; + const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + + return block_size == 32 && + components_a == 4 && + components_b == 4 && + nbits != 2 && + M >= kMinMForTileOptimization; } class MatMulNBitsSiluMulDecodeProgram final : public Program { @@ -168,6 +178,155 @@ class MatMulNBitsSiluMulDecodeProgram final : public Program { + public: + DP4AMatMulNBitsSiluMulDecodeProgram(uint32_t tile_size_k_vec, + uint32_t tile_size, + bool has_gate_bias, + bool has_up_bias, + bool single_scale_weights) + : Program{"DP4AMatMulNBitsSiluMulDecode"}, + tile_size_k_vec_(tile_size_k_vec), + tile_size_(tile_size), + has_gate_bias_(has_gate_bias), + has_up_bias_(has_up_bias), + single_scale_weights_(single_scale_weights) {} + + Status GenerateShaderCode(ShaderHelper& shader) const override { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform); + const auto& scales_a = shader.AddInput("scales_a", ShaderUsage::UseUniform); + const auto& gate_b = shader.AddInput("gate_b", ShaderUsage::UseUniform); + const auto& gate_scales_b = shader.AddInput("gate_scales_b", ShaderUsage::UseUniform); + const auto& up_b = shader.AddInput("up_b", ShaderUsage::UseUniform); + const auto& up_scales_b = shader.AddInput("up_scales_b", ShaderUsage::UseUniform); + if (has_gate_bias_) { + shader.AddInput("gate_bias", ShaderUsage::UseUniform); + } + if (has_up_bias_) { + shader.AddInput("up_bias", ShaderUsage::UseUniform); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + + return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul_silu_mul.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_gate_bias, has_gate_bias_), + WGSL_TEMPLATE_PARAMETER(has_up_bias, has_up_bias_), + WGSL_TEMPLATE_PARAMETER(has_zero_points, false), + WGSL_TEMPLATE_PARAMETER(n_bits, 4), + WGSL_TEMPLATE_PARAMETER(output_type_i32, false), + WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), + WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), + WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec_), + WGSL_TEMPLATE_VARIABLE(a, a), + WGSL_TEMPLATE_VARIABLE(gate_b, gate_b), + WGSL_TEMPLATE_VARIABLE(gate_scales_b, gate_scales_b), + WGSL_TEMPLATE_VARIABLE(output, output), + WGSL_TEMPLATE_VARIABLE(scales_a, scales_a), + WGSL_TEMPLATE_VARIABLE(up_b, up_b), + WGSL_TEMPLATE_VARIABLE(up_scales_b, up_scales_b)); + } + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"batch_count", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K16", ProgramUniformVariableDataType::Uint32}, + {"K32", ProgramUniformVariableDataType::Uint32}, + {"block_size", ProgramUniformVariableDataType::Uint32}, + {"blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"num_N_tile", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t tile_size_k_vec_; + uint32_t tile_size_; + bool has_gate_bias_; + bool has_up_bias_; + bool single_scale_weights_; +}; + +Status ApplyDP4AFusedDecodeMatMulNBitsSiluMul(const Tensor* a, + const Tensor* gate_b, + const Tensor* gate_scales, + const Tensor* gate_bias, + const Tensor* up_b, + const Tensor* up_scales, + const Tensor* up_bias, + uint32_t batch_count, + uint32_t N, + uint32_t K, + uint32_t block_size, + uint32_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { + ORT_ENFORCE(nbits == 4u, "DP4A fused decode path is specialized for 4-bit weights."); + + constexpr uint32_t kVec4Components = 4; + constexpr uint32_t kU32Components = 4; + constexpr uint32_t kBlockSizeA = 128; + + DP4AMatMulQuantizeProgram quantize_program; + quantize_program.SetWorkgroupSize(64); + constexpr uint32_t quantize_tile_size = 64 * kVec4Components; + quantize_program.SetDispatchGroupSize((batch_count * K + quantize_tile_size - 1) / quantize_tile_size, 1, 1); + + TensorShape a_quant_shape{batch_count, 1, K / kU32Components}; + Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); + TensorShapeVector a_scales_dims({batch_count, 1, 1, K / kBlockSizeA}); + Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); + + quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}}) + .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), 1}, + {&a_scale, ProgramTensorMetadataDependency::Rank, 1}}) + .AddUniformVariable({batch_count * K / kU32Components}); + ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); + + const bool has_gate_bias = gate_bias != nullptr; + const bool has_up_bias = up_bias != nullptr; + const bool single_scale_weights = (block_size == K * N); + const uint32_t block_size_per_col = single_scale_weights ? K : block_size; + const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; + const uint32_t blob_size = (block_size_per_col / 8) * nbits; + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + const uint32_t components_b_with_u32 = components_b * kU32Components; + + constexpr uint32_t workgroup_size = 128; + constexpr uint32_t tile_size = 4; + const uint32_t tile_size_k_vec = 32; + const uint32_t num_N_tile = CeilDiv(N, tile_size); + + DP4AMatMulNBitsSiluMulDecodeProgram program{tile_size_k_vec, + tile_size, + has_gate_bias, + has_up_bias, + single_scale_weights}; + program.SetWorkgroupSize(workgroup_size); + program.SetDispatchGroupSize(num_N_tile, 1, batch_count); + program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, + {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, + {gate_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {gate_scales, ProgramTensorMetadataDependency::TypeAndRank, 1}, + {up_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {up_scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1}) + .AddUniformVariables({batch_count, + N, + K, + K / 16, + K / 32, + block_size, + n_blocks_per_col, + num_N_tile}) + .CacheHint(tile_size_k_vec, tile_size, has_gate_bias, has_up_bias, single_scale_weights, "dp4a_decode_4bit"); + if (has_gate_bias) { + program.AddInput({gate_bias, ProgramTensorMetadataDependency::None}); + } + if (has_up_bias) { + program.AddInput({up_bias, ProgramTensorMetadataDependency::None}); + } + + return context.RunProgram(program); +} + class MatMulNBitsSiluMulProgram final : public Program { public: MatMulNBitsSiluMulProgram() : Program{"MatMulNBitsSiluMul"} {} @@ -227,30 +386,47 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& return Status::OK(); } - bool gate_would_use_generic_matmul = false; - bool up_would_use_generic_matmul = false; - ORT_RETURN_IF_ERROR(WouldApplyGenericMatMulNBitsInCurrentDispatch(a, - K_, - N_, - block_size_, - accuracy_level_, - bits_, - context, - y, - gate_would_use_generic_matmul)); - ORT_RETURN_IF_ERROR(WouldApplyGenericMatMulNBitsInCurrentDispatch(a, - K_, - N_, - block_size_, - accuracy_level_, - bits_, - context, - y, - up_would_use_generic_matmul)); - - if (IsFusedDecodeFastPathEnabled() && M == 1 && bits_ == kFusedDecodeFastPathBits && - block_size == kFusedDecodeFastPathBlockSize && gate_would_use_generic_matmul && - up_would_use_generic_matmul) { + const bool would_use_subgroup_unfused = + WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(a, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + context, + y); + const bool would_use_wide_tile_unfused = + WouldApplyWideTileMatMulNBitsInCurrentDispatch(a, + K_, + N_, + block_size_, + bits_); + + /* + if (!would_use_subgroup_unfused && + M == 1 && bits_ == 4 && + CanApplyDP4AFusedDecodePath(y, context, accuracy_level_, block_size, N, K)) { + return ApplyDP4AFusedDecodeMatMulNBitsSiluMul(a, + gate_b, + gate_scales, + gate_bias, + up_b, + up_scales, + up_bias, + batch_count, + N, + K, + block_size, + onnxruntime::narrow(bits_), + context, + y); + } + */ + + if (!would_use_subgroup_unfused && + !would_use_wide_tile_unfused && + M == 1 && bits_ == kFusedDecodeFastPathBits && + block_size == kFusedDecodeFastPathBlockSize) { ORT_ENFORCE(bits_ == kFusedDecodeFastPathBits, "MatMulNBitsSiluMulDecodeProgram is specialized for 4-bit weights only."); ORT_ENFORCE(block_size == kFusedDecodeFastPathBlockSize, diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 642afcb3f0c29..7f342ec809512 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -13,6 +13,7 @@ #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h" +#include "core/platform/env_var_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/platform/threadpool.h" @@ -102,12 +103,13 @@ namespace onnxruntime::optimizer_utils { namespace { -bool IsMatMulNBitsSiluFusionEnabled(const SessionOptions& session_options) { - const auto config_value = session_options.config_options.GetConfigOrDefault( - kOrtSessionOptionsEnableMatMulNBitsSiluFusion, - "0"); - return config_value != "0"; +constexpr const char* kOrtEnableMatMulNBitsSiluFusionEnvVar = "ORT_ENABLE_MATMUL_NBITS_SILU_FUSION"; + +#if !defined(ORT_MINIMAL_BUILD) +bool IsMatMulNBitsSiluFusionEnabled() { + return ParseEnvironmentVariableWithDefault(kOrtEnableMatMulNBitsSiluFusionEnvVar, 0) == 1; } +#endif } // namespace @@ -448,9 +450,10 @@ InlinedVector> GenerateTransformers( #endif transformers.emplace_back(std::make_unique(cpu_ep)); - if (IsMatMulNBitsSiluFusionEnabled(session_options)) { - transformers.emplace_back(std::make_unique(InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); - } + if (IsMatMulNBitsSiluFusionEnabled()) { + transformers.emplace_back(std::make_unique( + InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); + } #endif // !defined(DISABLE_CONTRIB_OPS) // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc index 901fc97c5c517..4d6b1a5791fbc 100644 --- a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc +++ b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc @@ -17,8 +17,10 @@ #include #include +#include "core/providers/webgpu/webgpu_provider_options.h" #include #include +#include "core/session/onnxruntime_run_options_config_keys.h" #include #include @@ -31,8 +33,10 @@ namespace { constexpr const char* kMatMulNBitsAutoTunerEnvVar = "ORT_WEBGPU_MATMUL_NBITS_ENABLE_AUTO_TUNER"; constexpr const char* kDecodeBenchmarkModeEnvVar = "ORT_WEBGPU_MATMUL_NBITS_BENCHMARK_MODE"; constexpr const char* kDecodeBenchmarkGpuEnvVar = "ORT_WEBGPU_MATMUL_NBITS_BENCHMARK_GPU"; +constexpr const char* kDecodeBenchmarkGraphCaptureEnvVar = "ORT_WEBGPU_MATMUL_NBITS_ENABLE_GRAPH_CAPTURE"; constexpr float kDecodeCorrectnessAbsTolerance = 0.1f; constexpr float kDecodeCorrectnessRelTolerance = 0.01f; +constexpr const char* kBenchmarkGraphCaptureAnnotationId = "1"; enum class DecodeBenchmarkMode { kPerf, @@ -45,6 +49,7 @@ enum class DecodeBenchmarkGpu { }; bool IsMatMulNBitsAutoTunerEnabled(); +bool IsGraphCaptureBenchmarkEnabled(); struct DecodeBenchConfig { int64_t n; @@ -165,9 +170,10 @@ std::string GetDecodeBenchmarkLabel() { const char* mode_label = IsDecodeBenchmarkPerfMode() ? "perf" : "correctness"; const char* adapter_label = GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kRtx5060Ti ? "rtx" : "t"; const char* tuner_label = IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off"; + const char* graph_label = IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"; std::ostringstream stream; - stream << "fp16_decode_" << mode_label << '_' << adapter_label << '_' << tuner_label; + stream << "fp16_decode_" << mode_label << '_' << adapter_label << '_' << tuner_label << '_' << graph_label; return stream.str(); } @@ -182,6 +188,26 @@ bool IsMatMulNBitsAutoTunerEnabled() { return auto_tuner_env != "0" && auto_tuner_env != "false" && auto_tuner_env != "off"; } +bool IsGraphCaptureBenchmarkEnabled() { + std::string graph_capture_env = onnxruntime::Env::Default().GetEnvironmentVar(kDecodeBenchmarkGraphCaptureEnvVar); + if (graph_capture_env.empty()) { + return false; + } + + std::transform(graph_capture_env.begin(), graph_capture_env.end(), graph_capture_env.begin(), + [](unsigned char value) { return static_cast(std::tolower(value)); }); + return graph_capture_env != "0" && graph_capture_env != "false" && graph_capture_env != "off"; +} + +Ort::RunOptions CreateBenchmarkRunOptions() { + Ort::RunOptions run_options; + if (IsGraphCaptureBenchmarkEnabled()) { + run_options.AddConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation, kBenchmarkGraphCaptureAnnotationId); + } + + return run_options; +} + std::vector GetRequiredDeviceFeatures(const wgpu::Adapter& adapter) { std::vector required_features; constexpr wgpu::FeatureName features[]{ @@ -729,7 +755,8 @@ std::string GetMlpDecodeBenchmarkLabel(MlpDecodeBenchmarkVariant variant) { stream << "fp16_mlp_decode_" << GetMlpVariantLabel(variant) << '_' << (IsDecodeBenchmarkPerfMode() ? "perf" : "correctness") << '_' << (GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kRtx5060Ti ? "rtx" : "t") << '_' - << (IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off"); + << (IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off") << '_' + << (IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"); return stream.str(); } @@ -853,6 +880,10 @@ Ort::Session CreateSessionFromModelData(const std::vector& model_data, session_options.DisableMemPattern(); session_options.SetGraphOptimizationLevel(graph_optimization_level); if (provider_options != nullptr) { + if (IsGraphCaptureBenchmarkEnabled()) { + session_options.AddConfigEntry(onnxruntime::webgpu::options::kEnableGraphCapture, + onnxruntime::webgpu::options::kEnableGraphCapture_ON); + } session_options.AppendExecutionProvider("WebGPU", *provider_options); } @@ -1010,6 +1041,7 @@ static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { activation.size(), input_shape.data(), input_shape.size()); + Ort::RunOptions run_options = CreateBenchmarkRunOptions(); if (!IsDecodeBenchmarkPerfMode()) { ValidateDecodeOutputs(model_data, session, input_names, &input_tensor, output_names); @@ -1017,14 +1049,14 @@ static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { // Warm up shader compilation, allocations, and caches before measured iterations. for (int i = 0; i < kDecodeWarmupRuns; ++i) { - auto warmup_outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); + auto warmup_outputs = session.Run(run_options, input_names, &input_tensor, 1, output_names, 1); benchmark::DoNotOptimize(warmup_outputs); } double total_kernel_seconds = 0.0; for (auto _ : state) { const auto kernel_start = std::chrono::steady_clock::now(); - auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); + auto outputs = session.Run(run_options, input_names, &input_tensor, 1, output_names, 1); const auto kernel_end = std::chrono::steady_clock::now(); total_kernel_seconds += std::chrono::duration(kernel_end - kernel_start).count(); benchmark::DoNotOptimize(outputs); @@ -1050,6 +1082,7 @@ static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { state.counters["PackedW_MB"] = benchmark::Counter(traffic.packed_weight_bytes / 1.0e6); state.counters["Scales_MB"] = benchmark::Counter(traffic.scale_bytes / 1.0e6); state.counters["Output_MB"] = benchmark::Counter(traffic.output_bytes / 1.0e6); + state.counters["GraphReplay"] = benchmark::Counter(IsGraphCaptureBenchmarkEnabled() ? 1.0 : 0.0); } catch (const std::exception& ex) { state.SkipWithError(ex.what()); } @@ -1097,6 +1130,7 @@ void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, MlpDecodeBench activation.size(), input_shape.data(), input_shape.size()); + Ort::RunOptions run_options = CreateBenchmarkRunOptions(); if (!IsDecodeBenchmarkPerfMode()) { ValidateMlpDecodeOutputs(SerializeMatMulNBitsMlpModel(config, MlpDecodeBenchmarkVariant::kUnfused), @@ -1108,14 +1142,14 @@ void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, MlpDecodeBench } for (int i = 0; i < kDecodeWarmupRuns; ++i) { - auto warmup_outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); + auto warmup_outputs = session.Run(run_options, input_names, &input_tensor, 1, output_names, 1); benchmark::DoNotOptimize(warmup_outputs); } double total_kernel_seconds = 0.0; for (auto _ : state) { const auto kernel_start = std::chrono::steady_clock::now(); - auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1); + auto outputs = session.Run(run_options, input_names, &input_tensor, 1, output_names, 1); const auto kernel_end = std::chrono::steady_clock::now(); total_kernel_seconds += std::chrono::duration(kernel_end - kernel_start).count(); benchmark::DoNotOptimize(outputs); @@ -1138,6 +1172,7 @@ void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, MlpDecodeBench state.counters["Scales_MB"] = benchmark::Counter(traffic.scale_bytes / 1.0e6); state.counters["Intermediate_MB"] = benchmark::Counter(traffic.intermediate_bytes / 1.0e6); state.counters["Output_MB"] = benchmark::Counter(traffic.output_bytes / 1.0e6); + state.counters["GraphReplay"] = benchmark::Counter(IsGraphCaptureBenchmarkEnabled() ? 1.0 : 0.0); } catch (const std::exception& ex) { state.SkipWithError(ex.what()); } diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index f7bfa3055f96d..ffa79abb4dc09 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -4,6 +4,7 @@ #include "core/common/inlined_containers.h" #include "core/graph/onnx_protobuf.h" #include "test/unittest_util/framework_test_utils.h" +#include "test/util/include/scoped_env_vars.h" #include "test/capturing_sink.h" #include "test/test_environment.h" #include "gtest/gtest.h" @@ -16,6 +17,19 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { +namespace { + +constexpr const char* kOrtEnableMatMulNBitsSiluFusionEnvVar = "ORT_ENABLE_MATMUL_NBITS_SILU_FUSION"; + +bool HasTransformerNamed(const InlinedVector>& transformers, + std::string_view name) { + return std::any_of(transformers.begin(), transformers.end(), [&](const auto& transformer) { + return transformer && transformer->Name() == name; + }); +} + +} // namespace + TEST(GraphTransformerUtilsTests, TestGenerateRewriterules) { // Generate all test auto rewrite_rules = optimizer_utils::GenerateRewriteRules(TransformerLevel::Level1); @@ -70,6 +84,35 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } + +TEST(GraphTransformerUtilsTests, MatMulNBitsSiluFusionDisabledByDefault) { +#if defined(DISABLE_CONTRIB_OPS) + GTEST_SKIP() << "MatMulNBitsSiluFusion requires contrib ops."; +#else + ScopedEnvironmentVariables scoped_env_vars{{kOrtEnableMatMulNBitsSiluFusionEnvVar, {}}}; + + CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + const auto& logger = DefaultLoggingManager().DefaultLogger(); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); + + EXPECT_FALSE(HasTransformerNamed(transformers, "MatMulNBitsSiluFusion")); +#endif +} + +TEST(GraphTransformerUtilsTests, MatMulNBitsSiluFusionEnabledViaEnvironmentVariable) { +#if defined(DISABLE_CONTRIB_OPS) + GTEST_SKIP() << "MatMulNBitsSiluFusion requires contrib ops."; +#else + ScopedEnvironmentVariables scoped_env_vars{{kOrtEnableMatMulNBitsSiluFusionEnvVar, std::string{"1"}}}; + + CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + const auto& logger = DefaultLoggingManager().DefaultLogger(); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); + + EXPECT_TRUE(HasTransformerNamed(transformers, "MatMulNBitsSiluFusion")); +#endif +} + TEST(GraphTransformerUtilsTests, TestDQMatMulNBitsFusionConfigWithContribGating) { SessionOptions session_options; const auto status = session_options.config_options.AddConfigEntry( From ad53b3d931008c40e458cd412ddf6d8414def05b Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Wed, 22 Apr 2026 13:50:57 -0700 Subject: [PATCH 06/26] Worka nd good perf --- .../quantization/matmul_nbits_qkv_sln.cc | 666 ++++++++++++++++++ .../quantization/matmul_nbits_qkv_sln.h | 42 ++ .../matmul_nbits_qkv_sln.wgsl.template | 310 ++++++++ .../webgpu/quantization/matmul_nbits_silu.cc | 289 ++++---- .../matmul_nbits_silu_mul.wgsl.template | 238 ++++--- ..._nbits_silu_mul_wide_tile_m1.wgsl.template | 127 ++++ .../webgpu/webgpu_contrib_kernels.cc | 2 + .../core/graph/contrib_ops/contrib_defs.cc | 79 +++ .../core/optimizer/graph_transformer_utils.cc | 23 +- .../optimizer/matmul_nbits_qkv_sln_fusion.cc | 289 ++++++++ .../optimizer/matmul_nbits_qkv_sln_fusion.h | 20 + .../core/providers/webgpu/program_manager.cc | 84 ++- .../webgpu_matmul_nbits_decode.cc | 543 +++++++++++++- .../optimizer/graph_transform_utils_test.cc | 38 +- .../matmul_nbits_qkv_sln_fusion_test.cc | 228 ++++++ 15 files changed, 2678 insertions(+), 300 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.cc create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.h create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.wgsl.template create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul_wide_tile_m1.wgsl.template create mode 100644 onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.cc create mode 100644 onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.h create mode 100644 onnxruntime/test/optimizer/matmul_nbits_qkv_sln_fusion_test.cc diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.cc new file mode 100644 index 0000000000000..0902dcb1617fb --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.cc @@ -0,0 +1,666 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.h" + +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" +#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" +#include "contrib_ops/webgpu/bert/skip_layer_norm.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/nn/layer_norm.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +namespace { + +constexpr unsigned int kMinMForTileOptimization = 4; + +bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + int64_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } + + const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t N = onnxruntime::narrow(helper.N()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_op); + +#if !defined(__wasm__) + int32_t subgroup_matrix_config_index = -1; + return (M >= kMinMForTileOptimization) && + (context.AdapterInfo().vendor == std::string_view{"apple"} || + context.AdapterInfo().vendor == std::string_view{"intel"}) && + CanApplySubgroupMatrixMatMulNBits(context, + accuracy_level, + block_size, + batch_count, + N, + K, + static_cast(nbits), + y->DataType() == DataTypeImpl::GetType(), + subgroup_matrix_config_index); +#endif + + return false; +} + +bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } + + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t N = onnxruntime::narrow(helper.N()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_op); + const uint32_t components_a = GetMaxComponents(K); + + return ((M >= kMinMForTileOptimization) || + y->DataType() == DataTypeImpl::GetType() || + context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && + CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a); +} + +bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t nbits) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } + + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_op); + const uint32_t components_a = GetMaxComponents(K); + const uint32_t block_size_per_col = block_size; + const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + + return block_size == 32 && + components_a == 4 && + components_b == 4 && + nbits != 2 && + M >= kMinMForTileOptimization; +} + +TensorShape GetOverrideShape(const TensorShape& shape, int components) { + return TensorShape{shape.Size() / components}; +} + +Status ApplySimplifiedLayerNorm(const Tensor* x, + const Tensor* scale, + float epsilon, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { + const auto& x_shape = x->Shape(); + if (x_shape.Size() == 0) { + return Status::OK(); + } + + const int64_t norm_size = x_shape[x_shape.NumDimensions() - 1]; + const uint32_t norm_count = onnxruntime::narrow(x_shape.Size() / norm_size); + const int components = GetMaxComponents(norm_size); + const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); + const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1; + + onnxruntime::webgpu::LayerNormProgram program{/*has_bias=*/false, + /*simplified=*/true, + /*has_mean_output=*/false, + /*has_inv_std_dev_output=*/false, + split_norm_dim}; + + program.CacheHint(components, true, split_norm_dim) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x_shape, components), components}, + {scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}}) + .AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}}) + .AddUniformVariables({{static_cast(components)}, + {norm_count}, + {static_cast(norm_size)}, + {norm_size_vectorized}, + {epsilon}}); + + if (split_norm_dim) { + const uint32_t workgroup_size_x = 128; + const uint32_t dispatch_size_x = onnxruntime::narrow(norm_size / (workgroup_size_x * components)); + program.SetDispatchGroupSize(dispatch_size_x, 1, 1) + .SetWorkgroupSize(workgroup_size_x); + } else { + program.SetDispatchGroupSize(norm_count); + } + + return context.RunProgram(program); +} + +Status ApplySkipSimplifiedLayerNorm(const Tensor* x, + const Tensor* skip, + const Tensor* scale, + float epsilon, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + Tensor* input_skip_bias_sum) { + const auto& x_shape = x->Shape(); + if (x_shape.Size() == 0) { + return Status::OK(); + } + + const uint32_t hidden_size = onnxruntime::narrow(x_shape[x_shape.NumDimensions() - 1]); + const int components = GetMaxComponents(hidden_size); + const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(x_shape.NumDimensions() - 1)); + const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1; + const uint32_t skip_size = onnxruntime::narrow(skip->Shape().Size()); + + SkipLayerNormProgram program{/*hasBeta=*/false, + /*hasBias=*/false, + epsilon, + hidden_size, + input_skip_bias_sum != nullptr, + /*simplified=*/true, + split_hidden_dim}; + program + .CacheHint(/*simplified=*/true, input_skip_bias_sum != nullptr, split_hidden_dim) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{scale, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{y, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * x_shape.Size() / hidden_size))) + .AddUniformVariables({{static_cast(components)}}) + .AddUniformVariables({{hidden_size}}) + .AddUniformVariables({{epsilon}}) + .AddUniformVariables({{skip_size}}); + + if (split_hidden_dim) { + const uint32_t workgroup_size_x = 128; + const uint32_t dispatch_size_x = (input_skip_bias_sum != nullptr ? 2u : 1u) * hidden_size / (workgroup_size_x * components); + program.SetDispatchGroupSize(dispatch_size_x, 1, 1) + .SetWorkgroupSize(workgroup_size_x); + } + + if (input_skip_bias_sum != nullptr) { + program.AddOutputs({{input_skip_bias_sum, ProgramTensorMetadataDependency::None, components}}); + } + + return context.RunProgram(program); +} + +Status ApplyUnfusedQKVSimplifiedLayerNorm(const Tensor* a, + const Tensor* norm_scale, + const Tensor* q_b, + const Tensor* q_scales, + const Tensor* k_b, + const Tensor* k_scales, + const Tensor* v_b, + const Tensor* v_scales, + int64_t K, + int64_t Nq, + int64_t Nkv, + int64_t block_size, + int64_t accuracy_level, + int64_t bits, + float epsilon, + onnxruntime::webgpu::ComputeContext& context, + Tensor* q_output, + Tensor* k_output, + Tensor* v_output) { + Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); + ORT_RETURN_IF_ERROR(ApplySimplifiedLayerNorm(a, norm_scale, epsilon, context, &normalized_a)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, q_b, q_scales, nullptr, nullptr, + K, Nq, block_size, accuracy_level, bits, context, q_output)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, k_b, k_scales, nullptr, nullptr, + K, Nkv, block_size, accuracy_level, bits, context, k_output)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, v_b, v_scales, nullptr, nullptr, + K, Nkv, block_size, accuracy_level, bits, context, v_output)); + return Status::OK(); +} + +Status ApplyUnfusedQKVSkipSimplifiedLayerNorm(const Tensor* a, + const Tensor* skip, + const Tensor* norm_scale, + const Tensor* q_b, + const Tensor* q_scales, + const Tensor* k_b, + const Tensor* k_scales, + const Tensor* v_b, + const Tensor* v_scales, + int64_t K, + int64_t Nq, + int64_t Nkv, + int64_t block_size, + int64_t accuracy_level, + int64_t bits, + float epsilon, + onnxruntime::webgpu::ComputeContext& context, + Tensor* q_output, + Tensor* k_output, + Tensor* v_output, + Tensor* input_skip_bias_sum) { + Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); + ORT_RETURN_IF_ERROR(ApplySkipSimplifiedLayerNorm(a, skip, norm_scale, epsilon, context, &normalized_a, input_skip_bias_sum)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, q_b, q_scales, nullptr, nullptr, + K, Nq, block_size, accuracy_level, bits, context, q_output)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, k_b, k_scales, nullptr, nullptr, + K, Nkv, block_size, accuracy_level, bits, context, k_output)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, v_b, v_scales, nullptr, nullptr, + K, Nkv, block_size, accuracy_level, bits, context, v_output)); + return Status::OK(); +} + +class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final + : public Program { + public: + MatMulNBitsQKVSimplifiedLayerNormDecodeProgram(uint32_t tile_size, + bool single_scale_weights, + uint32_t tile_size_k_vec, + uint32_t k_unroll_tiles, + bool has_full_q_tiles, + bool has_full_kv_tiles, + bool has_full_k_tiles, + bool has_skip_input, + bool has_skip_output) + : Program{"MatMulNBitsQKVSimplifiedLayerNormDecode"}, + tile_size_(tile_size), + single_scale_weights_(single_scale_weights), + tile_size_k_vec_(tile_size_k_vec), + k_unroll_tiles_(k_unroll_tiles), + has_full_q_tiles_(has_full_q_tiles), + has_full_kv_tiles_(has_full_kv_tiles), + has_full_k_tiles_(has_full_k_tiles), + has_skip_input_(has_skip_input), + has_skip_output_(has_skip_output) {} + + Status GenerateShaderCode(ShaderHelper& shader) const override { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto* skip = has_skip_input_ ? &shader.AddInput("skip", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr; + const auto& norm_scale = shader.AddInput("norm_scale", ShaderUsage::UseValueTypeAlias); + const auto& q_b = shader.AddInput("q_b", ShaderUsage::UseValueTypeAlias); + const auto& q_scales_b = shader.AddInput("q_scales_b"); + const auto& k_b = shader.AddInput("k_b"); + const auto& k_scales_b = shader.AddInput("k_scales_b"); + const auto& v_b = shader.AddInput("v_b"); + const auto& v_scales_b = shader.AddInput("v_scales_b"); + const auto& q_output = shader.AddOutput("q_output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& k_output = shader.AddOutput("k_output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& v_output = shader.AddOutput("v_output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto* input_skip_bias_sum = has_skip_output_ ? &shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr; + + const uint32_t components_a = a.NumComponents(); + const uint32_t components_b = q_b.NumComponents() / 4; + const uint32_t tile_size_k_vec = tile_size_k_vec_; + const uint32_t elements_in_value_b = components_b * 8u; + const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; + const uint32_t a_length_per_tile = tile_size_k / components_a; + const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec; + + if (skip != nullptr) { + if (input_skip_bias_sum != nullptr) { + return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv_sln.wgsl.template", + WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), + WGSL_TEMPLATE_PARAMETER(component_a, components_a), + WGSL_TEMPLATE_PARAMETER(component_b, components_b), + WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), + WGSL_TEMPLATE_PARAMETER(has_full_k_tiles, has_full_k_tiles_), + WGSL_TEMPLATE_PARAMETER(has_full_kv_tiles, has_full_kv_tiles_), + WGSL_TEMPLATE_PARAMETER(has_full_q_tiles, has_full_q_tiles_), + WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_), + WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_), + WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_), + WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), + WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), + WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), + WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), + WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), + WGSL_TEMPLATE_VARIABLE(a, a), + WGSL_TEMPLATE_VARIABLE(input_skip_bias_sum, *input_skip_bias_sum), + WGSL_TEMPLATE_VARIABLE(k_b, k_b), + WGSL_TEMPLATE_VARIABLE(k_output, k_output), + WGSL_TEMPLATE_VARIABLE(k_scales_b, k_scales_b), + WGSL_TEMPLATE_VARIABLE(norm_scale, norm_scale), + WGSL_TEMPLATE_VARIABLE(q_b, q_b), + WGSL_TEMPLATE_VARIABLE(q_output, q_output), + WGSL_TEMPLATE_VARIABLE(q_scales_b, q_scales_b), + WGSL_TEMPLATE_VARIABLE(skip, *skip), + WGSL_TEMPLATE_VARIABLE(v_b, v_b), + WGSL_TEMPLATE_VARIABLE(v_output, v_output), + WGSL_TEMPLATE_VARIABLE(v_scales_b, v_scales_b)); + } + + return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv_sln.wgsl.template", + WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), + WGSL_TEMPLATE_PARAMETER(component_a, components_a), + WGSL_TEMPLATE_PARAMETER(component_b, components_b), + WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), + WGSL_TEMPLATE_PARAMETER(has_full_k_tiles, has_full_k_tiles_), + WGSL_TEMPLATE_PARAMETER(has_full_kv_tiles, has_full_kv_tiles_), + WGSL_TEMPLATE_PARAMETER(has_full_q_tiles, has_full_q_tiles_), + WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_), + WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_), + WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_), + WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), + WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), + WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), + WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), + WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), + WGSL_TEMPLATE_VARIABLE(a, a), + WGSL_TEMPLATE_VARIABLE(k_b, k_b), + WGSL_TEMPLATE_VARIABLE(k_output, k_output), + WGSL_TEMPLATE_VARIABLE(k_scales_b, k_scales_b), + WGSL_TEMPLATE_VARIABLE(norm_scale, norm_scale), + WGSL_TEMPLATE_VARIABLE(q_b, q_b), + WGSL_TEMPLATE_VARIABLE(q_output, q_output), + WGSL_TEMPLATE_VARIABLE(q_scales_b, q_scales_b), + WGSL_TEMPLATE_VARIABLE(skip, *skip), + WGSL_TEMPLATE_VARIABLE(v_b, v_b), + WGSL_TEMPLATE_VARIABLE(v_output, v_output), + WGSL_TEMPLATE_VARIABLE(v_scales_b, v_scales_b)); + } + + return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv_sln.wgsl.template", + WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), + WGSL_TEMPLATE_PARAMETER(component_a, components_a), + WGSL_TEMPLATE_PARAMETER(component_b, components_b), + WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), + WGSL_TEMPLATE_PARAMETER(has_full_k_tiles, has_full_k_tiles_), + WGSL_TEMPLATE_PARAMETER(has_full_kv_tiles, has_full_kv_tiles_), + WGSL_TEMPLATE_PARAMETER(has_full_q_tiles, has_full_q_tiles_), + WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_), + WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_), + WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_), + WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), + WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), + WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), + WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), + WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), + WGSL_TEMPLATE_VARIABLE(a, a), + WGSL_TEMPLATE_VARIABLE(k_b, k_b), + WGSL_TEMPLATE_VARIABLE(k_output, k_output), + WGSL_TEMPLATE_VARIABLE(k_scales_b, k_scales_b), + WGSL_TEMPLATE_VARIABLE(norm_scale, norm_scale), + WGSL_TEMPLATE_VARIABLE(q_b, q_b), + WGSL_TEMPLATE_VARIABLE(q_output, q_output), + WGSL_TEMPLATE_VARIABLE(q_scales_b, q_scales_b), + WGSL_TEMPLATE_VARIABLE(v_b, v_b), + WGSL_TEMPLATE_VARIABLE(v_output, v_output), + WGSL_TEMPLATE_VARIABLE(v_scales_b, v_scales_b)); + } + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"Nq", ProgramUniformVariableDataType::Uint32}, + {"Nkv", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K_of_a", ProgramUniformVariableDataType::Uint32}, + {"K_of_b", ProgramUniformVariableDataType::Uint32}, + {"block_size", ProgramUniformVariableDataType::Uint32}, + {"blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"num_N_tile", ProgramUniformVariableDataType::Uint32}, + {"batch_count", ProgramUniformVariableDataType::Uint32}, + {"skip_size", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + uint32_t tile_size_; + bool single_scale_weights_; + uint32_t tile_size_k_vec_; + uint32_t k_unroll_tiles_; + bool has_full_q_tiles_; + bool has_full_kv_tiles_; + bool has_full_k_tiles_; + bool has_skip_input_; + bool has_skip_output_; +}; + +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBitsQKVSimplifiedLayerNorm, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", WebGpuSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBitsQKVSimplifiedLayerNorm); + +Status MatMulNBitsQKVSimplifiedLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* a = context.Input(0); + const Tensor* skip = context.Input(1); + const Tensor* norm_scale = context.Input(2); + const Tensor* q_b = context.Input(3); + const Tensor* q_scales = context.Input(4); + const Tensor* k_b = context.Input(5); + const Tensor* k_scales = context.Input(6); + const Tensor* v_b = context.Input(7); + const Tensor* v_scales = context.Input(8); + + ORT_ENFORCE(bits_ == 4, "MatMulNBitsQKVSimplifiedLayerNorm currently supports 4-bit weights only."); + ORT_ENFORCE(block_size_ == 32, "MatMulNBitsQKVSimplifiedLayerNorm currently supports block_size=32 only."); + + TensorShape q_b_shape({Nq_, K_}); + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), q_b_shape, false, true)); + + const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t Nq = onnxruntime::narrow(Nq_); + const uint32_t Nkv = onnxruntime::narrow(Nkv_); + + auto q_shape = helper.OutputShape(); + TensorShapeVector kv_dims(q_shape.GetDims().begin(), q_shape.GetDims().end()); + kv_dims.back() = Nkv_; + TensorShape kv_shape(kv_dims); + Tensor* q_output = context.Output(0, q_shape); + Tensor* k_output = context.Output(1, kv_shape); + Tensor* v_output = context.Output(2, kv_shape); + Tensor* input_skip_bias_sum = skip != nullptr ? context.Output(3, a->Shape()) : nullptr; + if (q_output->Shape().Size() == 0) { + return Status::OK(); + } + + ORT_ENFORCE(norm_scale->Shape().Size() == K_, "norm_scale must have shape [K]."); + + const bool would_use_subgroup_unfused = + WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(a, + K_, + Nq_, + block_size_, + accuracy_level_, + bits_, + context, + q_output); + const bool would_use_dp4a_unfused = + !would_use_subgroup_unfused && + WouldApplyDP4AMatMulNBitsInCurrentDispatch(a, + K_, + Nq_, + block_size_, + accuracy_level_, + context, + q_output); + const bool would_use_wide_tile_unfused = + !would_use_subgroup_unfused && + !would_use_dp4a_unfused && + WouldApplyWideTileMatMulNBitsInCurrentDispatch(a, + K_, + Nq_, + block_size_, + bits_); + + if (would_use_subgroup_unfused || would_use_dp4a_unfused || would_use_wide_tile_unfused || M != 1) { + if (skip != nullptr) { + return ApplyUnfusedQKVSkipSimplifiedLayerNorm(a, + skip, + norm_scale, + q_b, + q_scales, + k_b, + k_scales, + v_b, + v_scales, + K_, + Nq_, + Nkv_, + block_size_, + accuracy_level_, + bits_, + epsilon_, + context, + q_output, + k_output, + v_output, + input_skip_bias_sum); + } + return ApplyUnfusedQKVSimplifiedLayerNorm(a, + norm_scale, + q_b, + q_scales, + k_b, + k_scales, + v_b, + v_scales, + K_, + Nq_, + Nkv_, + block_size_, + accuracy_level_, + bits_, + epsilon_, + context, + q_output, + k_output, + v_output); + } + + const uint32_t block_size = onnxruntime::narrow(block_size_); + const uint32_t components_a = GetMaxComponents(K); + const uint32_t block_size_per_col = block_size; + const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; + const uint32_t blob_size = (block_size_per_col / 8) * static_cast(bits_); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + constexpr uint32_t kU32Components = 4; + const uint32_t components_b_with_u32 = components_b * kU32Components; + const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; + const bool single_scale_weights = + q_scales->Shape().Size() == 1 && k_scales->Shape().Size() == 1 && v_scales->Shape().Size() == 1; + + uint32_t workgroup_size = 128; + uint32_t tile_size = 8; + uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; + if (context.AdapterInfo().vendor != std::string_view{"intel"} && std::max(Nq, Nkv) <= 2048) { + workgroup_size = 64; + tile_size = 4; + tile_size_k_vec = 16; + } + + const uint32_t elements_in_value_b = components_b * (32u / onnxruntime::narrow(bits_)); + const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; + const bool has_full_q_tiles = (Nq % tile_size) == 0; + const bool has_full_kv_tiles = (Nkv % tile_size) == 0; + const bool has_full_k_tiles = (K % tile_size_k) == 0; + const uint32_t k_tile_iterations = K / tile_size_k; + + uint32_t k_unroll_tiles = 1; + if (has_full_k_tiles) { + if (k_tile_iterations >= 8 && std::max(Nq, Nkv) <= 2048 && + context.AdapterInfo().vendor != std::string_view{"intel"}) { + k_unroll_tiles = 4; + } else if (k_tile_iterations >= 4) { + k_unroll_tiles = 2; + } + } + + const uint32_t num_N_tile = CeilDiv(std::max(Nq, Nkv), tile_size); + MatMulNBitsQKVSimplifiedLayerNormDecodeProgram program{tile_size, + single_scale_weights, + tile_size_k_vec, + k_unroll_tiles, + has_full_q_tiles, + has_full_kv_tiles, + has_full_k_tiles, + skip != nullptr, + input_skip_bias_sum != nullptr}; + program.SetWorkgroupSize(workgroup_size); + program.SetDispatchGroupSize(num_N_tile, 1, batch_count); + program + .AddInput({a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + if (skip != nullptr) { + program.AddInput({skip, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + } + program + .AddInputs({{norm_scale, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, + {q_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {q_scales, ProgramTensorMetadataDependency::TypeAndRank}, + {k_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {k_scales, ProgramTensorMetadataDependency::TypeAndRank}, + {v_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {v_scales, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{q_output, ProgramTensorMetadataDependency::TypeAndRank}, + {k_output, ProgramTensorMetadataDependency::TypeAndRank}, + {v_output, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddUniformVariables({{Nq}, + {Nkv}, + {K}, + {K / components_a}, + {K_of_b}, + {block_size}, + {n_blocks_per_col}, + {num_N_tile}, + {batch_count}, + {skip != nullptr ? onnxruntime::narrow(skip->Shape().Size()) : 0u}, + {epsilon_}}) + .CacheHint(Nq, + Nkv, + K, + tile_size, + tile_size_k_vec, + k_unroll_tiles, + has_full_q_tiles, + has_full_kv_tiles, + has_full_k_tiles, + single_scale_weights, + skip != nullptr, + input_skip_bias_sum != nullptr, + "decode_qkv_sln"); + + if (input_skip_bias_sum != nullptr) { + program.AddOutput({input_skip_bias_sum, + ProgramTensorMetadataDependency::TypeAndRank, + static_cast(components_a)}); + } + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.h new file mode 100644 index 0000000000000..810ffbcdf4885 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class MatMulNBitsQKVSimplifiedLayerNorm final : public WebGpuKernel { + public: + explicit MatMulNBitsQKVSimplifiedLayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + K_ = info.GetAttr("K"); + Nq_ = info.GetAttr("Nq"); + Nkv_ = info.GetAttr("Nkv"); + block_size_ = info.GetAttr("block_size"); + bits_ = info.GetAttr("bits"); + accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4); + epsilon_ = info.GetAttrOrDefault("epsilon", 1e-6f); + ORT_ENFORCE(bits_ == 4, + "MatMulNBitsQKVSimplifiedLayerNorm currently supports 4-bit weights only."); + } + + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + private: + int64_t K_; + int64_t Nq_; + int64_t Nkv_; + int64_t block_size_; + int64_t accuracy_level_; + int64_t bits_; + float epsilon_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.wgsl.template new file mode 100644 index 0000000000000..18d3aa4270c67 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.wgsl.template @@ -0,0 +1,310 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param a_length_per_tile +#param component_a +#param component_b +#param elements_in_value_b +#param has_full_k_tiles +#param has_full_kv_tiles +#param has_full_q_tiles +#param has_skip_input +#param has_skip_output +#param k_unroll_tiles +#param single_scale_weights +#param sub_tile_count +#param tile_size_k_vec +#param tile_size_k +#param tile_size + +#use .getByOffset .setByOffset + +var sum_squared_shared : array; +var tile_A : array; +var q_inter_results : array, tile_size>; +var k_inter_results : array, tile_size>; +var v_inter_results : array, tile_size>; + +const default_zero_point = vec4(q_output_element_t(8)); + +fn unpack_nibble_values(word: u32) -> vec4 { + let unpacked = unpack4xU8(word); + return vec4(q_output_element_t(unpacked[0]), + q_output_element_t(unpacked[1]), + q_output_element_t(unpacked[2]), + q_output_element_t(unpacked[3])); +} + +fn load_merged_input(input_offset: u32) -> input_a_value_t { + var value = a.getByOffset(input_offset); +#if has_skip_input + let skip_offset = input_offset % (uniforms.skip_size / component_a); + value += input_a_value_t(skip.getByOffset(skip_offset)); +#endif + return value; +} + +#if component_a == 1 +fn load_a_vec4(a_offset: u32) -> vec4 { + return vec4(q_output_element_t(tile_A[a_offset]), + q_output_element_t(tile_A[a_offset + 1]), + q_output_element_t(tile_A[a_offset + 2]), + q_output_element_t(tile_A[a_offset + 3])); +} +#elif component_a == 2 +fn load_a_vec4(a_offset: u32) -> vec4 { + let a0 = tile_A[a_offset]; + let a1 = tile_A[a_offset + 1]; + return vec4(q_output_element_t(a0[0]), + q_output_element_t(a0[1]), + q_output_element_t(a1[0]), + q_output_element_t(a1[1])); +} +#elif component_a == 4 +fn load_a_vec4(a_offset: u32) -> vec4 { + let a = tile_A[a_offset]; + return vec4(q_output_element_t(a[0]), + q_output_element_t(a[1]), + q_output_element_t(a[2]), + q_output_element_t(a[3])); +} +#endif + +fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) { + let k_offset = kidx / component_a + col; + let input_offset = batch * uniforms.K_of_a + k_offset; +#if has_full_k_tiles + let merged_value = load_merged_input(input_offset); +#if has_skip_output + if (b_global_base == 0u) { + input_skip_bias_sum.setByOffset(input_offset, input_skip_bias_sum_value_t(merged_value)); + } +#endif + tile_A[col] = merged_value * input_a_value_t(input_a_element_t(inv_std)) * norm_scale.getByOffset(k_offset); +#else + if (k_offset < uniforms.K_of_a) { + let merged_value = load_merged_input(input_offset); +#if has_skip_output + if (b_global_base == 0u) { + input_skip_bias_sum.setByOffset(input_offset, input_skip_bias_sum_value_t(merged_value)); + } +#endif + tile_A[col] = merged_value * input_a_value_t(input_a_element_t(inv_std)) * norm_scale.getByOffset(k_offset); + } else { + tile_A[col] = input_a_value_t(0); + } +#endif +} + +fn compute_projection_sum(weight: q_b_value_t, + scale: q_output_element_t, + idx: u32) -> q_output_element_t { + var sum = q_output_element_t(0); + var a_offset = idx * (8 / component_a) * component_b; +#if component_b == 1 + let weight_lower = unpack_nibble_values(weight & 0x0F0F0F0Fu) - default_zero_point; + let weight_upper = unpack_nibble_values((weight >> 4) & 0x0F0F0F0Fu) - default_zero_point; + let w0 = vec4(q_output_element_t(weight_lower[0]), q_output_element_t(weight_upper[0]), q_output_element_t(weight_lower[1]), q_output_element_t(weight_upper[1])) * scale; + let w1 = vec4(q_output_element_t(weight_lower[2]), q_output_element_t(weight_upper[2]), q_output_element_t(weight_lower[3]), q_output_element_t(weight_upper[3])) * scale; +#if component_a == 1 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 4); + sum += dot(a0, w0) + dot(a1, w1); +#elif component_a == 2 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 2); + sum += dot(a0, w0) + dot(a1, w1); +#elif component_a == 4 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 1); + sum += dot(a0, w0) + dot(a1, w1); +#endif +#else + for (var i = 0u; i < component_b; i++) { + let weight_lower = unpack_nibble_values(weight[i] & 0x0F0F0F0Fu) - default_zero_point; + let weight_upper = unpack_nibble_values((weight[i] >> 4) & 0x0F0F0F0Fu) - default_zero_point; + let w0 = vec4(q_output_element_t(weight_lower[0]), q_output_element_t(weight_upper[0]), q_output_element_t(weight_lower[1]), q_output_element_t(weight_upper[1])) * scale; + let w1 = vec4(q_output_element_t(weight_lower[2]), q_output_element_t(weight_upper[2]), q_output_element_t(weight_lower[3]), q_output_element_t(weight_upper[3])) * scale; +#if component_a == 1 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 4); + sum += dot(a0, w0) + dot(a1, w1); + a_offset += 8; +#elif component_a == 2 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 2); + sum += dot(a0, w0) + dot(a1, w1); + a_offset += 4; +#elif component_a == 4 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 1); + sum += dot(a0, w0) + dot(a1, w1); + a_offset += 2; +#endif + } +#endif + return sum; +} + +fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy: u32, kidx: u32, inv_std: f32) { + for (var id = thread_idx; id < a_length_per_tile; id += workgroup_size_x) { + loadSHMA(batch, b_global_base, kidx, id, inv_std); + } + workgroupBarrier(); + + for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count) { + let b_global = b_global_base + local_row_offset + idy; + let k_offset = kidx / elements_in_value_b + idx; + #if !single_scale_weights + let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; + #endif + #if has_full_k_tiles + { + #else + if (k_offset < uniforms.K_of_b) { + #endif + #if has_full_q_tiles + #if !single_scale_weights + let q_scale_b = q_output_element_t(q_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); + #endif + let q_weight = q_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + q_inter_results[local_row_offset + idy][idx] += compute_projection_sum(q_weight, q_scale_b, idx); + #else + if (b_global < uniforms.Nq) { + #if !single_scale_weights + let q_scale_b = q_output_element_t(q_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); + #endif + let q_weight = q_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + q_inter_results[local_row_offset + idy][idx] += compute_projection_sum(q_weight, q_scale_b, idx); + } + #endif + #if has_full_kv_tiles + #if !single_scale_weights + let k_scale_b = q_output_element_t(k_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); + let v_scale_b = q_output_element_t(v_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); + #endif + let k_weight = k_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + let v_weight = v_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + k_inter_results[local_row_offset + idy][idx] += compute_projection_sum(k_weight, k_scale_b, idx); + v_inter_results[local_row_offset + idy][idx] += compute_projection_sum(v_weight, v_scale_b, idx); + #else + if (b_global < uniforms.Nkv) { + #if !single_scale_weights + let k_scale_b = q_output_element_t(k_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); + let v_scale_b = q_output_element_t(v_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); + #endif + let k_weight = k_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + let v_weight = v_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + k_inter_results[local_row_offset + idy][idx] += compute_projection_sum(k_weight, k_scale_b, idx); + v_inter_results[local_row_offset + idy][idx] += compute_projection_sum(v_weight, v_scale_b, idx); + } + #endif + } + } + workgroupBarrier(); + } + +$MAIN { + let batch = workgroup_idx / uniforms.num_N_tile; + let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; + + let idx = local_idx % tile_size_k_vec; + let idy = local_idx / tile_size_k_vec; + + if (local_idx < tile_size) { + for (var b = 0u; b < tile_size_k_vec; b++) { + q_inter_results[local_idx][b] = q_output_element_t(0); + k_inter_results[local_idx][b] = q_output_element_t(0); + v_inter_results[local_idx][b] = q_output_element_t(0); + } + } + + var sum_squared_local = 0.0; + for (var a_idx = local_idx; a_idx < uniforms.K_of_a; a_idx += workgroup_size_x) { + let a_value = load_merged_input(batch * uniforms.K_of_a + a_idx); +#if component_a == 1 + let a_f32 = f32(a_value); + sum_squared_local += a_f32 * a_f32; +#elif component_a == 2 + let a_f32 = vec2(a_value); + sum_squared_local += dot(a_f32, a_f32); +#elif component_a == 4 + let a_f32 = vec4(a_value); + sum_squared_local += dot(a_f32, a_f32); +#endif + } + sum_squared_shared[local_idx] = sum_squared_local; + workgroupBarrier(); + + var reduce_size : u32 = workgroup_size_x; + for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) { + reduce_size = curr_size + (reduce_size & 1u); + if (local_idx < curr_size) { + sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size]; + } + workgroupBarrier(); + } + + let inv_std = inverseSqrt(sum_squared_shared[0] / f32(uniforms.K) + uniforms.epsilon); + +#if single_scale_weights + let q_scale_b = q_output_element_t(q_scales_b.getByOffset(0)); + let k_scale_b = q_output_element_t(k_scales_b.getByOffset(0)); + let v_scale_b = q_output_element_t(v_scales_b.getByOffset(0)); +#endif + +#if k_unroll_tiles == 1 + for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + } +#elif k_unroll_tiles == 2 + let unrolled_k_step = tile_size_k * 2u; + let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step); + for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std); + } + for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + } +#elif k_unroll_tiles == 4 + let unrolled_k_step = tile_size_k * 4u; + let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step); + for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 2u, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 3u, inv_std); + } + for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + } +#endif + + if (local_idx < tile_size) { + let b_global = b_global_base + local_idx; + var q_output_value = q_output_element_t(0); + var k_output_value = q_output_element_t(0); + var v_output_value = q_output_element_t(0); + for (var b = 0u; b < tile_size_k_vec; b++) { + q_output_value += q_inter_results[local_idx][b]; + k_output_value += k_inter_results[local_idx][b]; + v_output_value += v_inter_results[local_idx][b]; + } +#if has_full_q_tiles + { +#else + if (b_global < uniforms.Nq) { +#endif + q_output.setByOffset(batch * uniforms.Nq + b_global, q_output_value_t(q_output_value)); + } +#if has_full_kv_tiles + { +#else + if (b_global < uniforms.Nkv) { +#endif + k_output.setByOffset(batch * uniforms.Nkv + b_global, k_output_value_t(k_output_value)); + v_output.setByOffset(batch * uniforms.Nkv + b_global, v_output_value_t(v_output_value)); + } + } +} diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc index 5f8eb9b92f836..86e8d0c0db964 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc @@ -24,19 +24,6 @@ constexpr unsigned int kMinMForTileOptimization = 4; constexpr uint32_t kFusedDecodeFastPathBits = 4u; constexpr uint32_t kFusedDecodeFastPathBlockSize = 32u; -bool CanApplyDP4AFusedDecodePath(const Tensor* y, - onnxruntime::webgpu::ComputeContext& context, - uint64_t accuracy_level, - uint32_t block_size, - uint32_t N, - uint32_t K) { - if (y->DataType() == DataTypeImpl::GetType()) { - return false; - } - - return CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, 4); -} - bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, int64_t K_op, int64_t N_op, @@ -104,19 +91,50 @@ bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, M >= kMinMForTileOptimization; } +bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } + + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t N = onnxruntime::narrow(helper.N()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_op); + const uint32_t components_a = GetMaxComponents(K); + + return ((M >= kMinMForTileOptimization) || + y->DataType() == DataTypeImpl::GetType() || + context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && + CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a); +} + class MatMulNBitsSiluMulDecodeProgram final : public Program { public: MatMulNBitsSiluMulDecodeProgram(uint32_t tile_size, bool has_gate_bias, bool has_up_bias, bool single_scale_weights, - uint32_t tile_size_k_vec) + uint32_t tile_size_k_vec, + uint32_t k_unroll_tiles, + bool has_full_n_tiles, + bool has_full_k_tiles) : Program{"MatMulNBitsSiluMulDecode"}, tile_size_(tile_size), has_gate_bias_(has_gate_bias), has_up_bias_(has_up_bias), single_scale_weights_(single_scale_weights), - tile_size_k_vec_(tile_size_k_vec) {} + tile_size_k_vec_(tile_size_k_vec), + k_unroll_tiles_(k_unroll_tiles), + has_full_n_tiles_(has_full_n_tiles), + has_full_k_tiles_(has_full_k_tiles) {} Status GenerateShaderCode(ShaderHelper& shader) const override { const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias); @@ -145,8 +163,11 @@ class MatMulNBitsSiluMulDecodeProgram final : public Program { +class MatMulNBitsSiluMulWideTileM1Program final : public Program { public: - DP4AMatMulNBitsSiluMulDecodeProgram(uint32_t tile_size_k_vec, - uint32_t tile_size, - bool has_gate_bias, + MatMulNBitsSiluMulWideTileM1Program(bool has_gate_bias, bool has_up_bias, - bool single_scale_weights) - : Program{"DP4AMatMulNBitsSiluMulDecode"}, - tile_size_k_vec_(tile_size_k_vec), - tile_size_(tile_size), + uint32_t outputs_per_thread) + : Program{"MatMulNBitsSiluMulWideTileM1Decode"}, has_gate_bias_(has_gate_bias), has_up_bias_(has_up_bias), - single_scale_weights_(single_scale_weights) {} + outputs_per_thread_(outputs_per_thread) {} Status GenerateShaderCode(ShaderHelper& shader) const override { - const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform); - const auto& scales_a = shader.AddInput("scales_a", ShaderUsage::UseUniform); - const auto& gate_b = shader.AddInput("gate_b", ShaderUsage::UseUniform); - const auto& gate_scales_b = shader.AddInput("gate_scales_b", ShaderUsage::UseUniform); - const auto& up_b = shader.AddInput("up_b", ShaderUsage::UseUniform); - const auto& up_scales_b = shader.AddInput("up_scales_b", ShaderUsage::UseUniform); + const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& gate_b = shader.AddInput("gate_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& gate_scales_b = shader.AddInput("gate_scales_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& up_b = shader.AddInput("up_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& up_scales_b = shader.AddInput("up_scales_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); if (has_gate_bias_) { shader.AddInput("gate_bias", ShaderUsage::UseUniform); } if (has_up_bias_) { shader.AddInput("up_bias", ShaderUsage::UseUniform); } - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul_silu_mul.wgsl.template", + return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_silu_mul_wide_tile_m1.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_gate_bias, has_gate_bias_), WGSL_TEMPLATE_PARAMETER(has_up_bias, has_up_bias_), - WGSL_TEMPLATE_PARAMETER(has_zero_points, false), - WGSL_TEMPLATE_PARAMETER(n_bits, 4), - WGSL_TEMPLATE_PARAMETER(output_type_i32, false), - WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), - WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec_), + WGSL_TEMPLATE_PARAMETER(outputs_per_thread, outputs_per_thread_), WGSL_TEMPLATE_VARIABLE(a, a), WGSL_TEMPLATE_VARIABLE(gate_b, gate_b), WGSL_TEMPLATE_VARIABLE(gate_scales_b, gate_scales_b), WGSL_TEMPLATE_VARIABLE(output, output), - WGSL_TEMPLATE_VARIABLE(scales_a, scales_a), WGSL_TEMPLATE_VARIABLE(up_b, up_b), WGSL_TEMPLATE_VARIABLE(up_scales_b, up_scales_b)); } @@ -228,105 +241,16 @@ class DP4AMatMulNBitsSiluMulDecodeProgram final : public Program(), a_quant_shape); - TensorShapeVector a_scales_dims({batch_count, 1, 1, K / kBlockSizeA}); - Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); - - quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}}) - .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), 1}, - {&a_scale, ProgramTensorMetadataDependency::Rank, 1}}) - .AddUniformVariable({batch_count * K / kU32Components}); - ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); - - const bool has_gate_bias = gate_bias != nullptr; - const bool has_up_bias = up_bias != nullptr; - const bool single_scale_weights = (block_size == K * N); - const uint32_t block_size_per_col = single_scale_weights ? K : block_size; - const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; - const uint32_t blob_size = (block_size_per_col / 8) * nbits; - const uint32_t blob_size_in_words = blob_size / 4; - const uint32_t components_b = GetMaxComponents(blob_size_in_words); - const uint32_t components_b_with_u32 = components_b * kU32Components; - - constexpr uint32_t workgroup_size = 128; - constexpr uint32_t tile_size = 4; - const uint32_t tile_size_k_vec = 32; - const uint32_t num_N_tile = CeilDiv(N, tile_size); - - DP4AMatMulNBitsSiluMulDecodeProgram program{tile_size_k_vec, - tile_size, - has_gate_bias, - has_up_bias, - single_scale_weights}; - program.SetWorkgroupSize(workgroup_size); - program.SetDispatchGroupSize(num_N_tile, 1, batch_count); - program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, - {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, - {gate_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, - {gate_scales, ProgramTensorMetadataDependency::TypeAndRank, 1}, - {up_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, - {up_scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1}) - .AddUniformVariables({batch_count, - N, - K, - K / 16, - K / 32, - block_size, - n_blocks_per_col, - num_N_tile}) - .CacheHint(tile_size_k_vec, tile_size, has_gate_bias, has_up_bias, single_scale_weights, "dp4a_decode_4bit"); - if (has_gate_bias) { - program.AddInput({gate_bias, ProgramTensorMetadataDependency::None}); - } - if (has_up_bias) { - program.AddInput({up_bias, ProgramTensorMetadataDependency::None}); - } - - return context.RunProgram(program); -} - class MatMulNBitsSiluMulProgram final : public Program { public: MatMulNBitsSiluMulProgram() : Program{"MatMulNBitsSiluMul"} {} @@ -379,6 +303,16 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& const uint32_t N = onnxruntime::narrow(helper.N()); const uint32_t K = onnxruntime::narrow(helper.K()); const uint32_t block_size = onnxruntime::narrow(block_size_); + const uint32_t components_a = GetMaxComponents(K); + const bool single_scale_weights = (block_size == K * N); + const uint32_t block_size_per_col = single_scale_weights ? K : block_size; + const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; + const uint32_t blob_size = (block_size_per_col / 8) * onnxruntime::narrow(bits_); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + constexpr uint32_t kU32Components = 4; + const uint32_t components_b_with_u32 = components_b * kU32Components; + const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; Tensor* y = context.Output(0, output_shape); const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); @@ -395,38 +329,30 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& bits_, context, y); - const bool would_use_wide_tile_unfused = + const bool would_use_dp4a_unfused = + WouldApplyDP4AMatMulNBitsInCurrentDispatch(a, + K_, + N_, + block_size_, + accuracy_level_, + context, + y); + const bool would_use_wide_tile_unfused = WouldApplyWideTileMatMulNBitsInCurrentDispatch(a, K_, N_, block_size_, bits_); - /* - if (!would_use_subgroup_unfused && - M == 1 && bits_ == 4 && - CanApplyDP4AFusedDecodePath(y, context, accuracy_level_, block_size, N, K)) { - return ApplyDP4AFusedDecodeMatMulNBitsSiluMul(a, - gate_b, - gate_scales, - gate_bias, - up_b, - up_scales, - up_bias, - batch_count, - N, - K, - block_size, - onnxruntime::narrow(bits_), - context, - y); - } - */ + // The experimental wide M==1 fused path regressed badly on NVIDIA decode shapes. + // Keep the implementation around for future work, but do not dispatch to it. if (!would_use_subgroup_unfused && + !would_use_dp4a_unfused && !would_use_wide_tile_unfused && M == 1 && bits_ == kFusedDecodeFastPathBits && block_size == kFusedDecodeFastPathBlockSize) { + //ORT_ENFORCE(false, "The experimental wide M==1 fused path regressed badly on NVIDIA decode shapes. Keep the implementation around for future work, but do not dispatch to it."); ORT_ENFORCE(bits_ == kFusedDecodeFastPathBits, "MatMulNBitsSiluMulDecodeProgram is specialized for 4-bit weights only."); ORT_ENFORCE(block_size == kFusedDecodeFastPathBlockSize, @@ -434,27 +360,44 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& const bool has_gate_bias = gate_bias != nullptr; const bool has_up_bias = up_bias != nullptr; - const bool single_scale_weights = (block_size == K * N); - const uint32_t block_size_per_col = single_scale_weights ? K : block_size; - const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; - const uint32_t blob_size = (block_size_per_col / 8) * onnxruntime::narrow(bits_); - const uint32_t blob_size_in_words = blob_size / 4; - const uint32_t components_a = GetMaxComponents(K); - const uint32_t components_b = GetMaxComponents(blob_size_in_words); - constexpr uint32_t kU32Components = 4; - const uint32_t components_b_with_u32 = components_b * kU32Components; - const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; - constexpr uint32_t workgroup_size = 128; - constexpr uint32_t tile_size = 8; - const uint32_t tile_size_k_vec = - (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; + uint32_t workgroup_size = 128; + uint32_t tile_size = 8; + uint32_t tile_size_k_vec = + (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; + + // For the smallest decode-like M==1 case, reduce K-split width and workgroup size + // so the generic fused kernel spends less time on reduction and barriers. + if (context.AdapterInfo().vendor != std::string_view{"intel"} && N <= 2048) { + workgroup_size = 64; + tile_size = 4; + tile_size_k_vec = 16; + } + + const uint32_t elements_in_value_b = components_b * (32u / onnxruntime::narrow(bits_)); + const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; + const bool has_full_n_tiles = (N % tile_size) == 0; + const bool has_full_k_tiles = (K % tile_size_k) == 0; + const uint32_t k_tile_iterations = K / tile_size_k; + + uint32_t k_unroll_tiles = 1; + if (has_full_k_tiles) { + if (k_tile_iterations >= 8 && N <= 2048 && context.AdapterInfo().vendor != std::string_view{"intel"}) { + k_unroll_tiles = 4; + } else if (k_tile_iterations >= 4) { + k_unroll_tiles = 2; + } + } + const uint32_t num_N_tile = CeilDiv(N, tile_size); MatMulNBitsSiluMulDecodeProgram program{tile_size, has_gate_bias, has_up_bias, single_scale_weights, - tile_size_k_vec}; + tile_size_k_vec, + k_unroll_tiles, + has_full_n_tiles, + has_full_k_tiles}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize(num_N_tile, 1, batch_count); program @@ -472,7 +415,14 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& {n_blocks_per_col}, {num_N_tile}, {batch_count}}) - .CacheHint(single_scale_weights, has_gate_bias, has_up_bias, tile_size_k_vec, "decode_4bit"); + .CacheHint(single_scale_weights, + has_gate_bias, + has_up_bias, + tile_size_k_vec, + k_unroll_tiles, + has_full_n_tiles, + has_full_k_tiles, + "decode_4bit"); if (has_gate_bias) { program.AddInput({gate_bias, ProgramTensorMetadataDependency::None}); } @@ -486,6 +436,7 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor gate_output = context.CreateGPUTensor(a->DataType(), output_shape); Tensor up_output = context.CreateGPUTensor(a->DataType(), output_shape); + //ORT_ENFORCE(false, "Reached prefill."); ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, gate_b, gate_scales, nullptr, gate_bias, K_, N_, block_size_, accuracy_level_, bits_, context, &gate_output)); ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, up_b, up_scales, nullptr, up_bias, K_, N_, block_size_, accuracy_level_, bits_, context, &up_output)); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template index 3976267b0a75b..25e3b869d0c21 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template @@ -5,6 +5,8 @@ #param component_a #param component_b #param elements_in_value_b +#param has_full_k_tiles +#param has_full_n_tiles #param single_scale_weights #param sub_tile_count #param tile_size_k_vec @@ -12,6 +14,7 @@ #param tile_size #param has_gate_bias #param has_up_bias +#param k_unroll_tiles #use .getByOffset .setByOffset @@ -24,11 +27,127 @@ const default_zero_point = output_element_t(8); fn loadSHMA(batch: u32, kidx: u32, col: u32) { let k_offset = kidx / component_a + col; - if (batch < uniforms.batch_count && k_offset < uniforms.K_of_a) { +#if has_full_k_tiles + tile_A[col] = a.getByOffset(batch * uniforms.K_of_a + k_offset); +#else + if (k_offset < uniforms.K_of_a) { tile_A[col] = a.getByOffset(batch * uniforms.K_of_a + k_offset); } else { tile_A[col] = input_a_value_t(0); } +#endif +} + +fn compute_gate_up_sums(b_global: u32, kidx: u32, idx: u32, k_offset: u32) -> vec2 { +#if !single_scale_weights + let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; + let gate_scale_b = gate_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); + let up_scale_b = up_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); +#endif + let gate_b_value = gate_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + let up_b_value = up_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + + var gate_sum = output_element_t(0); + var up_sum = output_element_t(0); + var a_offset = idx * (8 / component_a) * component_b; +#if component_b == 1 + let gate_b_value_lower = vec4(unpack4xU8(gate_b_value & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b_value_upper = vec4(unpack4xU8((gate_b_value >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b0 = vec4(gate_b_value_lower[0], gate_b_value_upper[0], gate_b_value_lower[1], gate_b_value_upper[1]) * gate_scale_b; + let gate_b1 = vec4(gate_b_value_lower[2], gate_b_value_upper[2], gate_b_value_lower[3], gate_b_value_upper[3]) * gate_scale_b; + let up_b_value_lower = vec4(unpack4xU8(up_b_value & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b_value_upper = vec4(unpack4xU8((up_b_value >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b0 = vec4(up_b_value_lower[0], up_b_value_upper[0], up_b_value_lower[1], up_b_value_upper[1]) * up_scale_b; + let up_b1 = vec4(up_b_value_lower[2], up_b_value_upper[2], up_b_value_lower[3], up_b_value_upper[3]) * up_scale_b; +#if component_a == 1 + let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]); + let a1 = vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); +#elif component_a == 2 + let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1]); + let a1 = vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); +#elif component_a == 4 + let a0 = tile_A[a_offset]; + let a1 = tile_A[a_offset + 1]; + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); +#endif +#else + for (var i = 0u; i < component_b; i++) { + let gate_b_value_lower = vec4(unpack4xU8(gate_b_value[i] & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b_value_upper = vec4(unpack4xU8((gate_b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b0 = vec4(gate_b_value_lower[0], gate_b_value_upper[0], gate_b_value_lower[1], gate_b_value_upper[1]) * gate_scale_b; + let gate_b1 = vec4(gate_b_value_lower[2], gate_b_value_upper[2], gate_b_value_lower[3], gate_b_value_upper[3]) * gate_scale_b; + let up_b_value_lower = vec4(unpack4xU8(up_b_value[i] & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b_value_upper = vec4(unpack4xU8((up_b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b0 = vec4(up_b_value_lower[0], up_b_value_upper[0], up_b_value_lower[1], up_b_value_upper[1]) * up_scale_b; + let up_b1 = vec4(up_b_value_lower[2], up_b_value_upper[2], up_b_value_lower[3], up_b_value_upper[3]) * up_scale_b; +#if component_a == 1 + let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]); + let a1 = vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); + a_offset += 8; +#elif component_a == 2 + let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1]); + let a1 = vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); + a_offset += 4; +#elif component_a == 4 + let a0 = tile_A[a_offset]; + let a1 = tile_A[a_offset + 1]; + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); + a_offset += 2; +#endif + } +#endif + + return vec2(gate_sum, up_sum); +} + +fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy: u32, kidx: u32) { + for (var id = thread_idx; id < a_length_per_tile; id += workgroup_size_x) + { + loadSHMA(batch, kidx, id); + } + workgroupBarrier(); + + for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count) + { + let b_global = b_global_base + local_row_offset + idy; + let k_offset = kidx / elements_in_value_b + idx; +#if has_full_n_tiles +#if !has_full_k_tiles + if (k_offset < uniforms.K_of_b) { +#endif +#else +#if has_full_k_tiles + if (b_global < uniforms.N) { +#else + if (b_global < uniforms.N && k_offset < uniforms.K_of_b) { +#endif +#endif + let sums = compute_gate_up_sums(b_global, kidx, idx, k_offset); + gate_inter_results[local_row_offset + idy][idx] += sums[0]; + up_inter_results[local_row_offset + idy][idx] += sums[1]; +#if has_full_n_tiles +#if !has_full_k_tiles + } +#endif +#else +#if has_full_k_tiles + } +#else + } +#endif +#endif + } + workgroupBarrier(); } $MAIN { @@ -52,94 +171,33 @@ $MAIN { let block_idx = 0u; #endif - for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) - { - for (var id = local_idx; id < a_length_per_tile; id += workgroup_size_x) - { - loadSHMA(batch, kidx, id); - } - workgroupBarrier(); - - for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count) - { - let b_global = b_global_base + local_row_offset + idy; - let k_offset = kidx / elements_in_value_b + idx; - if (b_global < uniforms.N && k_offset < uniforms.K_of_b) - { -#if !single_scale_weights - let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; - let gate_scale_b = gate_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); - let up_scale_b = up_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); -#endif - let gate_b_value = gate_b.getByOffset(b_global * uniforms.K_of_b + k_offset); - let up_b_value = up_b.getByOffset(b_global * uniforms.K_of_b + k_offset); - - var gate_sum = output_element_t(0); - var up_sum = output_element_t(0); - var a_offset = idx * (8 / component_a) * component_b; -#if component_b == 1 - let gate_b_value_lower = vec4(unpack4xU8(gate_b_value & 0x0F0F0F0Fu)) - vec4(default_zero_point); - let gate_b_value_upper = vec4(unpack4xU8((gate_b_value >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); - let gate_b0 = vec4(gate_b_value_lower[0], gate_b_value_upper[0], gate_b_value_lower[1], gate_b_value_upper[1]) * gate_scale_b; - let gate_b1 = vec4(gate_b_value_lower[2], gate_b_value_upper[2], gate_b_value_lower[3], gate_b_value_upper[3]) * gate_scale_b; - let up_b_value_lower = vec4(unpack4xU8(up_b_value & 0x0F0F0F0Fu)) - vec4(default_zero_point); - let up_b_value_upper = vec4(unpack4xU8((up_b_value >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); - let up_b0 = vec4(up_b_value_lower[0], up_b_value_upper[0], up_b_value_lower[1], up_b_value_upper[1]) * up_scale_b; - let up_b1 = vec4(up_b_value_lower[2], up_b_value_upper[2], up_b_value_lower[3], up_b_value_upper[3]) * up_scale_b; -#if component_a == 1 - let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]); - let a1 = vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]); - gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); - up_sum += dot(a0, up_b0) + dot(a1, up_b1); -#elif component_a == 2 - let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1]); - let a1 = vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]); - gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); - up_sum += dot(a0, up_b0) + dot(a1, up_b1); -#elif component_a == 4 - let a0 = tile_A[a_offset]; - let a1 = tile_A[a_offset + 1]; - gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); - up_sum += dot(a0, up_b0) + dot(a1, up_b1); -#endif -#else - for (var i = 0u; i < component_b; i++) { - let gate_b_value_lower = vec4(unpack4xU8(gate_b_value[i] & 0x0F0F0F0Fu)) - vec4(default_zero_point); - let gate_b_value_upper = vec4(unpack4xU8((gate_b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); - let gate_b0 = vec4(gate_b_value_lower[0], gate_b_value_upper[0], gate_b_value_lower[1], gate_b_value_upper[1]) * gate_scale_b; - let gate_b1 = vec4(gate_b_value_lower[2], gate_b_value_upper[2], gate_b_value_lower[3], gate_b_value_upper[3]) * gate_scale_b; - let up_b_value_lower = vec4(unpack4xU8(up_b_value[i] & 0x0F0F0F0Fu)) - vec4(default_zero_point); - let up_b_value_upper = vec4(unpack4xU8((up_b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); - let up_b0 = vec4(up_b_value_lower[0], up_b_value_upper[0], up_b_value_lower[1], up_b_value_upper[1]) * up_scale_b; - let up_b1 = vec4(up_b_value_lower[2], up_b_value_upper[2], up_b_value_lower[3], up_b_value_upper[3]) * up_scale_b; -#if component_a == 1 - let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]); - let a1 = vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]); - gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); - up_sum += dot(a0, up_b0) + dot(a1, up_b1); - a_offset += 8; -#elif component_a == 2 - let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1]); - let a1 = vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]); - gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); - up_sum += dot(a0, up_b0) + dot(a1, up_b1); - a_offset += 4; -#elif component_a == 4 - let a0 = tile_A[a_offset]; - let a1 = tile_A[a_offset + 1]; - gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); - up_sum += dot(a0, up_b0) + dot(a1, up_b1); - a_offset += 2; -#endif - } -#endif - - gate_inter_results[local_row_offset + idy][idx] += gate_sum; - up_inter_results[local_row_offset + idy][idx] += up_sum; - } - } - workgroupBarrier(); +#if k_unroll_tiles == 1 + for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx); + } +#elif k_unroll_tiles == 2 + let unrolled_k_step = tile_size_k * 2u; + let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step); + for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k); + } + for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx); + } +#elif k_unroll_tiles == 4 + let unrolled_k_step = tile_size_k * 4u; + let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step); + for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 2u); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 3u); + } + for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx); } +#endif if (batch >= uniforms.batch_count) { return; @@ -154,7 +212,11 @@ $MAIN { } let b_global = b_global_base + local_idx; let output_idx = batch * uniforms.N + b_global; +#if has_full_n_tiles + { +#else if (b_global < uniforms.N) { +#endif #if has_gate_bias gate_output_value += gate_bias[b_global]; #endif diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul_wide_tile_m1.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul_wide_tile_m1.wgsl.template new file mode 100644 index 0000000000000..47c292ba0ccf1 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul_wide_tile_m1.wgsl.template @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param has_gate_bias +#param has_up_bias +#param outputs_per_thread + +#use .getByOffset .setByOffset + +const KAVecSizeForBlock32 = 8u; +const kTileN : u32 = workgroup_size_x * outputs_per_thread; +const kDefaultZeroPoint = output_element_t(8); + +var a_data_tile : array; + +fn load_a(batch : u32, col : u32) -> input_a_value_t { + if (batch < uniforms.batch_count && col < uniforms.K_of_a) { + let offset = batch * uniforms.K_of_a + col; + return a.getByOffset(offset); + } + + return input_a_value_t(); +} + +fn load_gate_b(row : u32, block_idx : u32) -> vec4 { + if (row < uniforms.N && block_idx < uniforms.K_of_b) { + let offset = row * uniforms.K_of_b + block_idx; + return gate_b.getByOffset(offset); + } + + return vec4(); +} + +fn load_up_b(row : u32, block_idx : u32) -> vec4 { + if (row < uniforms.N && block_idx < uniforms.K_of_b) { + let offset = row * uniforms.K_of_b + block_idx; + return up_b.getByOffset(offset); + } + + return vec4(); +} + +fn dequantize_u4_block(packed_data : u32, + scale : output_element_t) -> mat2x4 { + let lower : vec4 = unpack4xU8(packed_data & 0x0F0F0F0Fu); + let upper : vec4 = unpack4xU8((packed_data >> 4u) & 0x0F0F0F0Fu); + + let zero_matrix : mat2x4 = mat2x4( + kDefaultZeroPoint, kDefaultZeroPoint, kDefaultZeroPoint, kDefaultZeroPoint, + kDefaultZeroPoint, kDefaultZeroPoint, kDefaultZeroPoint, kDefaultZeroPoint); + + var dequantized_values : mat2x4 = mat2x4( + output_element_t(lower[0]), output_element_t(upper[0]), + output_element_t(lower[1]), output_element_t(upper[1]), + output_element_t(lower[2]), output_element_t(upper[2]), + output_element_t(lower[3]), output_element_t(upper[3])); + + dequantized_values = (dequantized_values - zero_matrix) * scale; + return dequantized_values; +} + +$MAIN { + let batch = workgroup_id.z; + let col_base = workgroup_id.x * kTileN + local_idx; + + var gate_results : array; + var up_results : array; + for (var output_idx = 0u; output_idx < outputs_per_thread; output_idx++) { + gate_results[output_idx] = output_element_t(0); + up_results[output_idx] = output_element_t(0); + } + + for (var block_idx = 0u; block_idx < uniforms.n_blocks_per_col; block_idx++) { + if (local_idx < KAVecSizeForBlock32) { + a_data_tile[local_idx] = load_a(batch, block_idx * KAVecSizeForBlock32 + local_idx); + } + workgroupBarrier(); + + for (var output_idx = 0u; output_idx < outputs_per_thread; output_idx++) { + let col = col_base + output_idx * workgroup_size_x; + if (col < uniforms.N) { + let gate_scale = gate_scales_b.getByOffset(col * uniforms.n_blocks_per_col + block_idx); + let up_scale = up_scales_b.getByOffset(col * uniforms.n_blocks_per_col + block_idx); + let gate_b_data = load_gate_b(col, block_idx); + let up_b_data = load_up_b(col, block_idx); + + for (var b_idx = 0u; b_idx < 4u; b_idx++) { + let gate_dequantized = dequantize_u4_block(gate_b_data[b_idx], gate_scale); + let up_dequantized = dequantize_u4_block(up_b_data[b_idx], up_scale); + let a_data0 = a_data_tile[b_idx * 2u]; + let a_data1 = a_data_tile[b_idx * 2u + 1u]; + + gate_results[output_idx] += dot(a_data0, gate_dequantized[0]) + + dot(a_data1, gate_dequantized[1]); + up_results[output_idx] += dot(a_data0, up_dequantized[0]) + + dot(a_data1, up_dequantized[1]); + } + } + } + + workgroupBarrier(); + } + + if (batch >= uniforms.batch_count) { + return; + } + + for (var output_idx = 0u; output_idx < outputs_per_thread; output_idx++) { + let col = col_base + output_idx * workgroup_size_x; + if (col >= uniforms.N) { + continue; + } + + var gate_result = gate_results[output_idx]; + var up_result = up_results[output_idx]; +#if has_gate_bias + gate_result += gate_bias[col]; +#endif +#if has_up_bias + up_result += up_bias[col]; +#endif + + let one = output_element_t(1.0); + let silu_value = gate_result * (one / (one + exp(-gate_result))); + output.setByOffset(batch * uniforms.N + col, silu_value * up_result); + } +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 9389be885fbdf..05d2a4b4d6a37 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -22,6 +22,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gr // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBitsQKVSimplifiedLayerNorm); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBitsSiluMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu); @@ -51,6 +52,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 869c76dc15a2a..a1aba7b9475f1 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3677,6 +3677,85 @@ derived from the runtime shape of A and the shared attributes K and N. } }); + static const char* MatMulNBitsQKVSimplifiedLayerNorm_ver1_doc = R"DOC( +MatMulNBitsQKVSimplifiedLayerNorm fuses either SimplifiedLayerNormalization (RMSNorm) +or SkipSimplifiedLayerNormalization with three MatMulNBits projections that share the +same normalized activation. + + A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon) + Q = MatMulNBits(A_norm, q_weight) + K = MatMulNBits(A_norm, k_weight) + V = MatMulNBits(A_norm, v_weight) + +If skip is provided, the operator computes the SkipSimplifiedLayerNormalization variant +and may also return the input+skip residual sum as output 3. + +This operator is intended as a decode-oriented QKV fusion primitive. +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBitsQKVSimplifiedLayerNorm) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulNBitsQKVSimplifiedLayerNorm_ver1_doc) + .Attr("K", "Input feature dimension shared by the normalized input and all projection weights.", AttributeProto::INT) + .Attr("Nq", "Output feature dimension of the Q projection.", AttributeProto::INT) + .Attr("Nkv", "Output feature dimension shared by the K and V projections.", AttributeProto::INT) + .Attr("bits", "Bit-width used to quantize all weight matrices (valid range: 2~8)", AttributeProto::INT, static_cast(4)) + .Attr("block_size", + "Size of each quantization block along the K dimension. Must be a power of two and >= 16.", + AttributeProto::INT) + .Attr("accuracy_level", + "The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.", + AttributeProto::INT, static_cast(0)) + .Attr("epsilon", "Epsilon used by the simplified layer norm reduction.", AttributeProto::FLOAT, 1e-6f) + .Input(0, "A", "The shared input tensor.", "T1") + .Input(1, "skip", "Optional residual input for SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) + .Input(2, "norm_scale", "Scale input for the simplified layer norm with shape [K].", "T1") + .Input(3, "q_B", "Packed uint8 tensor for the Q projection weights.", "T2") + .Input(4, "q_scales", "Per-block scaling factors for the Q projection.", "T1") + .Input(5, "k_B", "Packed uint8 tensor for the K projection weights.", "T2") + .Input(6, "k_scales", "Per-block scaling factors for the K projection.", "T1") + .Input(7, "v_B", "Packed uint8 tensor for the V projection weights.", "T2") + .Input(8, "v_scales", "Per-block scaling factors for the V projection.", "T1") + .Output(0, "Q", "The Q projection output tensor.", "T1") + .Output(1, "K", "The K projection output tensor.", "T1") + .Output(2, "V", "The V projection output tensor.", "T1") + .Output(3, "input_skip_bias_sum", "Optional residual-sum output for SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + for (int output_index = 0; output_index < ctx.getNumOutputs(); ++output_index) { + propagateElemTypeFromInputToOutput(ctx, 0, output_index); + } + + if (!hasInputShape(ctx, 0)) { + return; + } + + const auto& input_shape = getInputShape(ctx, 0); + if (input_shape.dim_size() == 0) { + fail_shape_inference("A must have rank >= 1"); + } + + const int64_t q_out_features = getAttribute(ctx, "Nq", -1); + const int64_t kv_out_features = getAttribute(ctx, "Nkv", -1); + + auto set_output_shape = [&](int output_index, int64_t out_features) { + auto* output_shape = getOutputShape(ctx, output_index); + *output_shape = input_shape; + output_shape->mutable_dim(output_shape->dim_size() - 1)->set_dim_value(out_features); + }; + + set_output_shape(0, q_out_features); + set_output_shape(1, kv_out_features); + set_output_shape(2, kv_out_features); + if (ctx.getNumOutputs() > 3) { + auto* output_shape = getOutputShape(ctx, 3); + *output_shape = input_shape; + } + }); + static const char* MatMulBnb4_ver1_doc = R"DOC( MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 7f342ec809512..46c2dcec1543a 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -56,6 +56,7 @@ #include "core/optimizer/layer_norm_fusion.h" #include "core/optimizer/matmul_activation_fusion.h" #include "core/optimizer/matmul_add_fusion.h" +#include "core/optimizer/matmul_nbits_qkv_sln_fusion.h" #include "core/optimizer/matmul_nbits_silu_fusion.h" #include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/matmul_integer_to_float.h" @@ -104,10 +105,18 @@ namespace onnxruntime::optimizer_utils { namespace { constexpr const char* kOrtEnableMatMulNBitsSiluFusionEnvVar = "ORT_ENABLE_MATMUL_NBITS_SILU_FUSION"; +constexpr const char* kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar = + "ORT_ENABLE_MATMUL_NBITS_QKV_SIMPLIFIED_LAYER_NORM_FUSION"; #if !defined(ORT_MINIMAL_BUILD) bool IsMatMulNBitsSiluFusionEnabled() { - return ParseEnvironmentVariableWithDefault(kOrtEnableMatMulNBitsSiluFusionEnvVar, 0) == 1; + return ParseEnvironmentVariableWithDefault(kOrtEnableMatMulNBitsSiluFusionEnvVar, 1) == 1; + //return true; +} + +bool IsMatMulNBitsQKVSimplifiedLayerNormFusionEnabled() { + return ParseEnvironmentVariableWithDefault(kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar, 1) == 1; + //return true; } #endif @@ -450,10 +459,14 @@ InlinedVector> GenerateTransformers( #endif transformers.emplace_back(std::make_unique(cpu_ep)); - if (IsMatMulNBitsSiluFusionEnabled()) { - transformers.emplace_back(std::make_unique( - InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); - } + if (IsMatMulNBitsSiluFusionEnabled()) { + transformers.emplace_back(std::make_unique( + InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); + } + if (IsMatMulNBitsQKVSimplifiedLayerNormFusionEnabled()) { + transformers.emplace_back(std::make_unique( + InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); + } #endif // !defined(DISABLE_CONTRIB_OPS) // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.cc new file mode 100644 index 0000000000000..1e8bc2346c8d7 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.cc @@ -0,0 +1,289 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/matmul_nbits_qkv_sln_fusion.h" + +#include +#include + +#include "core/graph/graph_utils.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +namespace { + +bool HasInput(const Node& node, size_t index) { + return index < node.InputDefs().size() && node.InputDefs()[index] != nullptr && !node.InputDefs()[index]->Name().empty(); +} + +bool IsSupportedSimplifiedLayerNormalization(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SimplifiedLayerNormalization", {1}); +} + +bool IsSupportedSkipSimplifiedLayerNormalization(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SkipSimplifiedLayerNormalization", {1}, kMSDomain); +} + +bool IsSupportedNormForFusion(const Node& node) { + return IsSupportedSimplifiedLayerNormalization(node) || IsSupportedSkipSimplifiedLayerNormalization(node); +} + +bool HasProducedOutput(const Node& node, size_t index) { + return index < node.OutputDefs().size() && node.OutputDefs()[index] != nullptr && !node.OutputDefs()[index]->Name().empty(); +} + +bool IsMatMulNBitsWithoutOptionalInputs(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain) && + !HasInput(node, 3) && !HasInput(node, 4) && !HasInput(node, 5); +} + +int64_t GetIntAttr(const Node& node, const char* name, int64_t default_value, bool required = false) { + const auto* attr = graph_utils::GetNodeAttribute(node, name); + if (attr == nullptr) { + ORT_ENFORCE(!required, "Missing required attribute ", name, " on node ", node.Name()); + return default_value; + } + + return attr->i(); +} + +float GetFloatAttr(const Node& node, const char* name, float default_value) { + const auto* attr = graph_utils::GetNodeAttribute(node, name); + return attr == nullptr ? default_value : attr->f(); +} + +struct QkvNodes { + const Node* q = nullptr; + const Node* k = nullptr; + const Node* v = nullptr; +}; + +std::optional GetQkvNodes(const Graph& graph, const Node& norm) { + if (!HasProducedOutput(norm, 0) || graph.NodeProducesGraphOutput(norm)) { + return std::nullopt; + } + + std::array consumers{}; + size_t consumer_index = 0; + for (auto edge_it = norm.OutputEdgesBegin(); edge_it != norm.OutputEdgesEnd(); ++edge_it) { + if (edge_it->GetSrcArgIndex() != 0) { + continue; + } + + if (consumer_index >= consumers.size()) { + return std::nullopt; + } + + if (edge_it->GetDstArgIndex() != 0) { + return std::nullopt; + } + + const Node* consumer = graph.GetNode(edge_it->GetNode().Index()); + if (consumer == nullptr || !IsMatMulNBitsWithoutOptionalInputs(*consumer)) { + return std::nullopt; + } + + consumers[consumer_index++] = consumer; + } + + if (consumer_index != consumers.size()) { + return std::nullopt; + } + + const int64_t n0 = GetIntAttr(*consumers[0], "N", -1, true); + const int64_t n1 = GetIntAttr(*consumers[1], "N", -1, true); + const int64_t n2 = GetIntAttr(*consumers[2], "N", -1, true); + + QkvNodes qkv; + if (n0 != n1 && n1 == n2) { + qkv = {consumers[0], consumers[1], consumers[2]}; + } else if (n1 != n0 && n0 == n2) { + qkv = {consumers[1], consumers[0], consumers[2]}; + } else if (n2 != n0 && n0 == n1) { + qkv = {consumers[2], consumers[0], consumers[1]}; + } else { + return std::nullopt; + } + + return qkv; +} + +bool HasSupportedExecutionProvider(const Node& node) { + const auto& node_ep = node.GetExecutionProviderType(); + return node_ep.empty() || node_ep == kWebGpuExecutionProvider; +} + +bool IsFuseCandidate(const Node& norm, const QkvNodes& qkv) { + if (!IsSupportedNormForFusion(norm) || qkv.q == nullptr || qkv.k == nullptr || qkv.v == nullptr) { + return false; + } + + if (!HasSupportedExecutionProvider(norm) || !HasSupportedExecutionProvider(*qkv.q) || + !HasSupportedExecutionProvider(*qkv.k) || !HasSupportedExecutionProvider(*qkv.v)) { + return false; + } + + const size_t min_norm_inputs = IsSupportedSkipSimplifiedLayerNormalization(norm) ? 3u : 2u; + if (norm.InputDefs().size() < min_norm_inputs || qkv.q->InputDefs().empty() || qkv.k->InputDefs().empty() || qkv.v->InputDefs().empty()) { + return false; + } + + if (qkv.q->InputDefs()[0] != norm.OutputDefs()[0] || qkv.k->InputDefs()[0] != norm.OutputDefs()[0] || + qkv.v->InputDefs()[0] != norm.OutputDefs()[0]) { + return false; + } + + const int64_t q_k = GetIntAttr(*qkv.q, "K", -1, true); + const int64_t k_k = GetIntAttr(*qkv.k, "K", -1, true); + const int64_t v_k = GetIntAttr(*qkv.v, "K", -1, true); + const int64_t q_bits = GetIntAttr(*qkv.q, "bits", 4); + const int64_t k_bits = GetIntAttr(*qkv.k, "bits", 4); + const int64_t v_bits = GetIntAttr(*qkv.v, "bits", 4); + const int64_t q_block_size = GetIntAttr(*qkv.q, "block_size", -1, true); + const int64_t k_block_size = GetIntAttr(*qkv.k, "block_size", -1, true); + const int64_t v_block_size = GetIntAttr(*qkv.v, "block_size", -1, true); + const int64_t q_accuracy_level = GetIntAttr(*qkv.q, "accuracy_level", 0); + const int64_t k_accuracy_level = GetIntAttr(*qkv.k, "accuracy_level", 0); + const int64_t v_accuracy_level = GetIntAttr(*qkv.v, "accuracy_level", 0); + + return q_k == k_k && q_k == v_k && + q_bits == k_bits && q_bits == v_bits && + q_block_size == k_block_size && q_block_size == v_block_size && + q_accuracy_level == k_accuracy_level && q_accuracy_level == v_accuracy_level; +} + +} // namespace + +Status MatMulNBitsQKVSimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* node_ptr = graph.GetNode(node_index); + if (node_ptr == nullptr) { + continue; + } + + auto& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!IsSupportedNormForFusion(node)) { + continue; + } + + const auto qkv_nodes = GetQkvNodes(graph, node); + if (!qkv_nodes || !IsFuseCandidate(node, *qkv_nodes)) { + continue; + } + + const int64_t K = GetIntAttr(*qkv_nodes->q, "K", -1, true); + const int64_t Nq = GetIntAttr(*qkv_nodes->q, "N", -1, true); + const int64_t Nkv = GetIntAttr(*qkv_nodes->k, "N", -1, true); + const int64_t bits = GetIntAttr(*qkv_nodes->q, "bits", 4); + const int64_t block_size = GetIntAttr(*qkv_nodes->q, "block_size", -1, true); + const int64_t accuracy_level = GetIntAttr(*qkv_nodes->q, "accuracy_level", 0); + const float epsilon = GetFloatAttr(node, "epsilon", 1e-6f); + + const bool is_skip_sln = IsSupportedSkipSimplifiedLayerNormalization(node); + + LOGS(logger, INFO) << "MatMulNBitsQKVSimplifiedLayerNormFusion: matched norm='" << node.Name() + << "' q='" << qkv_nodes->q->Name() << "' k='" << qkv_nodes->k->Name() + << "' v='" << qkv_nodes->v->Name() << "' attrs={K=" << K + << ", Nq=" << Nq << ", Nkv=" << Nkv << ", bits=" << bits + << ", block_size=" << block_size << ", accuracy_level=" << accuracy_level + << ", epsilon=" << epsilon << ", skip_sln=" << is_skip_sln << "}"; + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", K), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("Nq", Nq), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("Nkv", Nkv), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("epsilon", epsilon), attrs); + + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + + InlinedVector fused_inputs{ + const_cast(node.InputDefs()[0]), + is_skip_sln ? const_cast(node.InputDefs()[1]) : &empty_arg, + const_cast(node.InputDefs()[is_skip_sln ? 2 : 1]), + const_cast(qkv_nodes->q->InputDefs()[1]), + const_cast(qkv_nodes->q->InputDefs()[2]), + const_cast(qkv_nodes->k->InputDefs()[1]), + const_cast(qkv_nodes->k->InputDefs()[2]), + const_cast(qkv_nodes->v->InputDefs()[1]), + const_cast(qkv_nodes->v->InputDefs()[2]), + }; + + InlinedVector fused_outputs{ + const_cast(qkv_nodes->q->OutputDefs()[0]), + const_cast(qkv_nodes->k->OutputDefs()[0]), + const_cast(qkv_nodes->v->OutputDefs()[0]), + }; + if (is_skip_sln && HasProducedOutput(node, 3)) { + fused_outputs.push_back(const_cast(node.OutputDefs()[3])); + } + + const auto norm_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node); + const auto q_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->q); + const auto k_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->k); + const auto v_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->v); + const auto norm_output_edges = is_skip_sln && HasProducedOutput(node, 3) + ? graph_utils::GraphEdge::GetNodeOutputEdges(node) + : std::vector{}; + + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->q)); + graph.RemoveNode(qkv_nodes->q->Index()); + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->k)); + graph.RemoveNode(qkv_nodes->k->Index()); + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->v)); + graph.RemoveNode(qkv_nodes->v->Index()); + graph_utils::RemoveNodeOutputEdges(graph, node); + graph.RemoveNode(node.Index()); + + Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsQKVSimplifiedLayerNorm"), + "MatMulNBitsQKVSimplifiedLayerNorm", + "fused SimplifiedLayerNormalization with Q/K/V MatMulNBits projections", + fused_inputs, + fused_outputs, + &attrs, + kMSDomain); + fused_node.SetExecutionProviderType(kWebGpuExecutionProvider); + + for (const auto& input_edge : norm_input_edges) { + int fused_input_index = input_edge.dst_arg_index; + if (!is_skip_sln && input_edge.dst_arg_index == 1) { + fused_input_index = 2; + } + + graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index); + } + + for (const auto& output_edge : q_output_edges) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 0, output_edge.dst_arg_index); + } + for (const auto& output_edge : k_output_edges) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 1, output_edge.dst_arg_index); + } + for (const auto& output_edge : v_output_edges) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 2, output_edge.dst_arg_index); + } + if (is_skip_sln && HasProducedOutput(node, 3)) { + for (const auto& output_edge : norm_output_edges) { + if (output_edge.src_arg_index == 3) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 3, output_edge.dst_arg_index); + } + } + } + + modified = true; + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.h b/onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.h new file mode 100644 index 0000000000000..a40cc98459818 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +class MatMulNBitsQKVSimplifiedLayerNormFusion : public GraphTransformer { + public: + explicit MatMulNBitsQKVSimplifiedLayerNormFusion( + const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("MatMulNBitsQKVSimplifiedLayerNormFusion", compatible_execution_providers) {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index e4376476a885d..a4654182d3b68 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "core/common/common.h" #include "core/common/logging/logging.h" @@ -13,6 +14,72 @@ namespace onnxruntime { namespace webgpu { +namespace { + +const char* CompilationMessageTypeToString(wgpu::CompilationMessageType type) { + switch (type) { + case wgpu::CompilationMessageType::Error: + return "error"; + case wgpu::CompilationMessageType::Warning: + return "warning"; + case wgpu::CompilationMessageType::Info: + return "info"; + default: + return "unknown"; + } +} + +std::string GetShaderCompilationDiagnostics(WebGpuContext& webgpu_context, const wgpu::ShaderModule& shader_module) { + struct CompilationInfoContext { + std::string diagnostics; + } compilation_info_context; + + auto future = shader_module.GetCompilationInfo( + wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::CompilationInfoRequestStatus status, const wgpu::CompilationInfo* compilation_info, CompilationInfoContext* context) { + if (status != wgpu::CompilationInfoRequestStatus::Success) { + context->diagnostics = std::string{"Shader compilation info unavailable. Request status: "} + + (status == wgpu::CompilationInfoRequestStatus::CallbackCancelled ? "callback cancelled" : "unknown"); + return; + } + + if (compilation_info == nullptr || compilation_info->messageCount == 0 || compilation_info->messages == nullptr) { + return; + } + + std::string diagnostics; + diagnostics.reserve(compilation_info->messageCount * 96); + for (size_t i = 0; i < compilation_info->messageCount; ++i) { + const auto& message = compilation_info->messages[i]; + diagnostics += "\n ["; + diagnostics += CompilationMessageTypeToString(message.type); + diagnostics += "]"; + if (message.lineNum > 0) { + diagnostics += " line "; + diagnostics += std::to_string(message.lineNum); + if (message.linePos > 0) { + diagnostics += ':'; + diagnostics += std::to_string(message.linePos); + } + } + diagnostics += ": "; + diagnostics += std::string_view{message.message}; + } + + context->diagnostics = std::move(diagnostics); + }, + &compilation_info_context); + + const Status wait_status = webgpu_context.Wait(future); + if (!wait_status.IsOK()) { + return std::string{"Shader compilation info wait failed: "} + wait_status.ErrorMessage(); + } + + return compilation_info_context.diagnostics; +} + +} // namespace + ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks) : name{program.Name()}, compute_pipeline{compute_pipeline}, @@ -197,7 +264,7 @@ Status ProgramManager::Build(const ProgramBase& program, struct CreateComputePipelineContext { wgpu::ComputePipeline& pipeline; - Status status; + std::string error_message; } create_pipeline_context{compute_pipeline, {}}; ORT_RETURN_IF_ERROR( @@ -209,12 +276,23 @@ Status ProgramManager::Build(const ProgramBase& program, if (status == wgpu::CreatePipelineAsyncStatus::Success) { context->pipeline = std::move(pipeline); } else { - context->status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create a WebGPU compute pipeline: ", std::string_view{message}); + context->error_message = "Failed to create a WebGPU compute pipeline: "; + context->error_message.append(message.data, message.length); } }, &create_pipeline_context))); - return create_pipeline_context.status; + if (create_pipeline_context.error_message.empty()) { + return Status::OK(); + } + + const std::string compilation_diagnostics = GetShaderCompilationDiagnostics(webgpu_context_, shader_module); + if (compilation_diagnostics.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, create_pipeline_context.error_message); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, create_pipeline_context.error_message, + "\nShader compilation diagnostics:", compilation_diagnostics); } const ProgramArtifact* ProgramManager::Get(const std::string& key) const { diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc index 4d6b1a5791fbc..193e01c4c6aa8 100644 --- a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc +++ b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -128,6 +129,36 @@ struct MlpTrafficStats { double total_bytes; }; +struct QkvDecodeBenchConfig { + int64_t q_n; + int64_t kv_n; + int64_t k; + int64_t bits; + int64_t block_size; + int64_t accuracy_level; +}; + +enum class QkvDecodeBenchmarkVariant { + kUnfused, + kFused, +}; + +enum class QkvNormKind { + kSimplified, + kSkipSimplified, +}; + +struct QkvTrafficStats { + double input_bytes; + double skip_input_bytes; + double norm_scale_bytes; + double packed_weight_bytes; + double scale_bytes; + double intermediate_bytes; + double output_bytes; + double total_bytes; +}; + constexpr double kRtx5060TiTheoreticalBandwidthBytesPerSecond = 448.0 * 1000.0 * 1000.0 * 1000.0; constexpr int kDecodeWarmupRuns = 25; @@ -319,28 +350,61 @@ DecodeTrafficStats CalculateDecodeTrafficStats(const DecodeBenchConfig& config) }; } - MlpTrafficStats CalculateMlpTrafficStats(const MlpDecodeBenchConfig& config, MlpDecodeBenchmarkVariant variant) { - const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; - const int64_t blob_size = (config.block_size * config.bits) / 8; - - const double input_reads = variant == MlpDecodeBenchmarkVariant::kFused ? 1.0 : 2.0; - const double intermediate_bytes = - variant == MlpDecodeBenchmarkVariant::kFused ? 0.0 : 4.0 * static_cast(config.n) * sizeof(Ort::Float16_t); - const double input_bytes = input_reads * static_cast(config.k) * sizeof(Ort::Float16_t); - const double packed_weight_bytes = - 2.0 * static_cast(config.n) * static_cast(k_blocks) * static_cast(blob_size); - const double scale_bytes = 2.0 * static_cast(config.n) * static_cast(k_blocks) * sizeof(Ort::Float16_t); - const double output_bytes = static_cast(config.n) * sizeof(Ort::Float16_t); - - return { - input_bytes, - packed_weight_bytes, - scale_bytes, - intermediate_bytes, - output_bytes, - input_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, - }; - } +MlpTrafficStats CalculateMlpTrafficStats(const MlpDecodeBenchConfig& config, MlpDecodeBenchmarkVariant variant) { + const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; + const int64_t blob_size = (config.block_size * config.bits) / 8; + + const double input_reads = variant == MlpDecodeBenchmarkVariant::kFused ? 1.0 : 2.0; + const double intermediate_bytes = + variant == MlpDecodeBenchmarkVariant::kFused ? 0.0 : 4.0 * static_cast(config.n) * sizeof(Ort::Float16_t); + const double input_bytes = input_reads * static_cast(config.k) * sizeof(Ort::Float16_t); + const double packed_weight_bytes = + 2.0 * static_cast(config.n) * static_cast(k_blocks) * static_cast(blob_size); + const double scale_bytes = 2.0 * static_cast(config.n) * static_cast(k_blocks) * sizeof(Ort::Float16_t); + const double output_bytes = static_cast(config.n) * sizeof(Ort::Float16_t); + + return { + input_bytes, + packed_weight_bytes, + scale_bytes, + intermediate_bytes, + output_bytes, + input_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, + }; +} + +QkvTrafficStats CalculateQkvTrafficStats(const QkvDecodeBenchConfig& config, + QkvDecodeBenchmarkVariant variant, + QkvNormKind norm_kind) { + const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; + const int64_t blob_size = (config.block_size * config.bits) / 8; + + const double input_bytes = static_cast(config.k) * sizeof(Ort::Float16_t); + const double skip_input_bytes = norm_kind == QkvNormKind::kSkipSimplified + ? static_cast(config.k) * sizeof(Ort::Float16_t) + : 0.0; + const double norm_scale_bytes = static_cast(config.k) * sizeof(Ort::Float16_t); + const double packed_weight_bytes = + static_cast(config.q_n + 2 * config.kv_n) * static_cast(k_blocks) * static_cast(blob_size); + const double scale_bytes = + static_cast(config.q_n + 2 * config.kv_n) * static_cast(k_blocks) * sizeof(Ort::Float16_t); + const double intermediate_bytes = + variant == QkvDecodeBenchmarkVariant::kUnfused ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; + const double output_bytes = + static_cast(config.q_n + 2 * config.kv_n + (norm_kind == QkvNormKind::kSkipSimplified ? config.k : 0)) * + sizeof(Ort::Float16_t); + + return { + input_bytes, + skip_input_bytes, + norm_scale_bytes, + packed_weight_bytes, + scale_bytes, + intermediate_bytes, + output_bytes, + input_bytes + skip_input_bytes + norm_scale_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, + }; +} AdapterSelectionConfig GetAdapterSelectionConfig() { if (GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kT1000) { @@ -585,11 +649,16 @@ std::vector GetDecodeBenchConfigs() { } std::vector GetMlpDecodeBenchConfigs() { - // Each entry is {N, K, bits, block_size, accuracy_level} for a decode-style M=1 MLP run. + // Qwen3-1.7B MLP gate/up decode geometry: hidden=2048, intermediate=6144. return { {6144, 2048, 4, 32, 4}, - {8192, 3072, 4, 32, 4}, - {11008, 4096, 4, 32, 4}, + }; +} + +std::vector GetQkvDecodeBenchConfigs() { + // Qwen3-1.7B attention projection geometry: hidden=2048, q=2048, kv=1024. + return { + {2048, 1024, 2048, 4, 32, 4}, }; } @@ -693,6 +762,78 @@ void AddMatMulNBitsSiluMulNode(ONNX_NAMESPACE::GraphProto& graph, attr_accuracy->set_i(accuracy_level); } +void AddMatMulNBitsQKVSimplifiedLayerNormNode(ONNX_NAMESPACE::GraphProto& graph, + const std::string& node_name, + const std::string& input_name, + const std::string& skip_input_name, + const std::string& norm_scale_name, + const std::string& q_weight_name, + const std::string& q_scale_name, + const std::string& k_weight_name, + const std::string& k_scale_name, + const std::string& v_weight_name, + const std::string& v_scale_name, + const std::string& q_output_name, + const std::string& k_output_name, + const std::string& v_output_name, + const std::string& skip_sum_output_name, + int64_t k, + int64_t q_n, + int64_t kv_n, + int64_t bits, + int64_t block_size, + int64_t accuracy_level, + float epsilon) { + auto* node = graph.add_node(); + node->set_name(node_name); + node->set_op_type("MatMulNBitsQKVSimplifiedLayerNorm"); + node->set_domain("com.microsoft"); + node->add_input(input_name); + node->add_input(skip_input_name); + node->add_input(norm_scale_name); + node->add_input(q_weight_name); + node->add_input(q_scale_name); + node->add_input(k_weight_name); + node->add_input(k_scale_name); + node->add_input(v_weight_name); + node->add_input(v_scale_name); + node->add_output(q_output_name); + node->add_output(k_output_name); + node->add_output(v_output_name); + if (!skip_sum_output_name.empty()) { + node->add_output(skip_sum_output_name); + } + + auto* attr_k = node->add_attribute(); + attr_k->set_name("K"); + attr_k->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_k->set_i(k); + auto* attr_qn = node->add_attribute(); + attr_qn->set_name("Nq"); + attr_qn->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_qn->set_i(q_n); + auto* attr_kvn = node->add_attribute(); + attr_kvn->set_name("Nkv"); + attr_kvn->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_kvn->set_i(kv_n); + auto* attr_bits = node->add_attribute(); + attr_bits->set_name("bits"); + attr_bits->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_bits->set_i(bits); + auto* attr_block = node->add_attribute(); + attr_block->set_name("block_size"); + attr_block->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_block->set_i(block_size); + auto* attr_accuracy = node->add_attribute(); + attr_accuracy->set_name("accuracy_level"); + attr_accuracy->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr_accuracy->set_i(accuracy_level); + auto* attr_epsilon = node->add_attribute(); + attr_epsilon->set_name("epsilon"); + attr_epsilon->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); + attr_epsilon->set_f(epsilon); +} + std::vector SerializeMatMulNBitsModel(const DecodeBenchConfig& config) { const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; const int64_t blob_size = (config.block_size * config.bits) / 8; @@ -760,6 +901,24 @@ std::string GetMlpDecodeBenchmarkLabel(MlpDecodeBenchmarkVariant variant) { return stream.str(); } +std::string GetQkvVariantLabel(QkvDecodeBenchmarkVariant variant) { + return variant == QkvDecodeBenchmarkVariant::kFused ? "fused" : "unfused"; +} + +std::string GetQkvNormKindLabel(QkvNormKind norm_kind) { + return norm_kind == QkvNormKind::kSkipSimplified ? "skip_simplified" : "simplified"; +} + +std::string GetQkvDecodeBenchmarkLabel(QkvDecodeBenchmarkVariant variant, QkvNormKind norm_kind) { + std::ostringstream stream; + stream << "fp16_qkv_norm_" << GetQkvNormKindLabel(norm_kind) << '_' << GetQkvVariantLabel(variant) << '_' + << (IsDecodeBenchmarkPerfMode() ? "perf" : "correctness") << '_' + << (GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kRtx5060Ti ? "rtx" : "t") << '_' + << (IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off") << '_' + << (IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"); + return stream.str(); +} + std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& config, MlpDecodeBenchmarkVariant variant) { const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; @@ -873,6 +1032,128 @@ std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& co return std::vector(serialized.begin(), serialized.end()); } +std::vector SerializeMatMulNBitsQkvModel(const QkvDecodeBenchConfig& config, + QkvDecodeBenchmarkVariant variant, + QkvNormKind norm_kind) { + const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; + const int64_t blob_size = (config.block_size * config.bits) / 8; + + ONNX_NAMESPACE::ModelProto model; + model.set_ir_version(10); + + auto* onnx_opset = model.add_opset_import(); + onnx_opset->set_domain(""); + onnx_opset->set_version(21); + auto* ms_opset = model.add_opset_import(); + ms_opset->set_domain("com.microsoft"); + ms_opset->set_version(1); + + auto* graph = model.mutable_graph(); + graph->set_name(variant == QkvDecodeBenchmarkVariant::kFused + ? (norm_kind == QkvNormKind::kSkipSimplified ? "WebGpuMatMulNBitsQkvSkipNormFused" : "WebGpuMatMulNBitsQkvSimplifiedNormFused") + : (norm_kind == QkvNormKind::kSkipSimplified ? "WebGpuMatMulNBitsQkvSkipNormUnfused" : "WebGpuMatMulNBitsQkvSimplifiedNormUnfused")); + + auto* input = graph->add_input(); + input->set_name("A"); + input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); + + if (norm_kind == QkvNormKind::kSkipSimplified) { + auto* skip_input = graph->add_input(); + skip_input->set_name("Skip"); + skip_input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + skip_input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + skip_input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); + } + + auto add_output = [&](const std::string& name, int64_t n) { + auto* output = graph->add_output(); + output->set_name(name); + output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(n); + }; + add_output("Q", config.q_n); + add_output("K", config.kv_n); + add_output("V", config.kv_n); + if (norm_kind == QkvNormKind::kSkipSimplified) { + add_output("SkipSum", config.k); + } + + std::vector norm_scale(static_cast(config.k), Ort::Float16_t(1.0f)); + std::vector q_b(static_cast(config.q_n * k_blocks * blob_size), uint8_t{0x11}); + std::vector k_b(static_cast(config.kv_n * k_blocks * blob_size), uint8_t{0x33}); + std::vector v_b(static_cast(config.kv_n * k_blocks * blob_size), uint8_t{0x77}); + std::vector q_scales(static_cast(config.q_n * k_blocks), Ort::Float16_t(0.03125f)); + std::vector k_scales(static_cast(config.kv_n * k_blocks), Ort::Float16_t(0.03125f)); + std::vector v_scales(static_cast(config.kv_n * k_blocks), Ort::Float16_t(0.0625f)); + + AddTensorInitializer(*graph, "norm_scale", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.k}, norm_scale); + AddTensorInitializer(*graph, "q_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, {config.q_n, k_blocks, blob_size}, q_b); + AddTensorInitializer(*graph, "q_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.q_n, k_blocks}, q_scales); + AddTensorInitializer(*graph, "k_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, {config.kv_n, k_blocks, blob_size}, k_b); + AddTensorInitializer(*graph, "k_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.kv_n, k_blocks}, k_scales); + AddTensorInitializer(*graph, "v_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, {config.kv_n, k_blocks, blob_size}, v_b); + AddTensorInitializer(*graph, "v_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.kv_n, k_blocks}, v_scales); + + if (variant == QkvDecodeBenchmarkVariant::kFused) { + AddMatMulNBitsQKVSimplifiedLayerNormNode(*graph, + "MatMulNBitsQKVSimplifiedLayerNormDecode", + "A", + norm_kind == QkvNormKind::kSkipSimplified ? "Skip" : "", + "norm_scale", + "q_B", + "q_scales", + "k_B", + "k_scales", + "v_B", + "v_scales", + "Q", + "K", + "V", + norm_kind == QkvNormKind::kSkipSimplified ? "SkipSum" : "", + config.k, + config.q_n, + config.kv_n, + config.bits, + config.block_size, + config.accuracy_level, + 1e-6f); + } else { + AddTensorValueInfo(*graph, "A_norm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); + auto* norm = graph->add_node(); + norm->set_name(norm_kind == QkvNormKind::kSkipSimplified ? "InputSkipSimplifiedLayerNorm" : "InputSimplifiedLayerNorm"); + norm->set_op_type(norm_kind == QkvNormKind::kSkipSimplified ? "SkipSimplifiedLayerNormalization" : "SimplifiedLayerNormalization"); + if (norm_kind == QkvNormKind::kSkipSimplified) { + AddTensorValueInfo(*graph, "SkipSum", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); + norm->set_domain("com.microsoft"); + norm->add_input("A"); + norm->add_input("Skip"); + norm->add_input("norm_scale"); + norm->add_output("A_norm"); + norm->add_output(""); + norm->add_output(""); + norm->add_output("SkipSum"); + } else { + norm->add_input("A"); + norm->add_input("norm_scale"); + norm->add_output("A_norm"); + } + auto* attr_epsilon = norm->add_attribute(); + attr_epsilon->set_name("epsilon"); + attr_epsilon->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); + attr_epsilon->set_f(1e-6f); + + AddMatMulNBitsNode(*graph, "QMatMulNBitsDecode", "A_norm", "q_B", "q_scales", "Q", config.k, config.q_n, config.bits, config.block_size, config.accuracy_level); + AddMatMulNBitsNode(*graph, "KMatMulNBitsDecode", "A_norm", "k_B", "k_scales", "K", config.k, config.kv_n, config.bits, config.block_size, config.accuracy_level); + AddMatMulNBitsNode(*graph, "VMatMulNBitsDecode", "A_norm", "v_B", "v_scales", "V", config.k, config.kv_n, config.bits, config.block_size, config.accuracy_level); + } + + const auto serialized = model.SerializeAsString(); + return std::vector(serialized.begin(), serialized.end()); +} + Ort::Session CreateSessionFromModelData(const std::vector& model_data, const std::unordered_map* provider_options, GraphOptimizationLevel graph_optimization_level = GraphOptimizationLevel::ORT_ENABLE_ALL) { @@ -1004,7 +1285,47 @@ void ValidateMlpDecodeOutputs(const std::vector& unfused_model_data, << " at index " << max_abs_diff_index << std::endl; } -static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { +void ValidateQkvDecodeOutputs(const std::vector& unfused_model_data, + const std::vector& fused_model_data, + const std::unordered_map& provider_options, + const char* const* input_names, + const Ort::Value* input_tensors, + size_t input_count, + const char* const* output_names, + size_t output_count) { + Ort::Session unfused_session = CreateSessionFromModelData(unfused_model_data, &provider_options, GraphOptimizationLevel::ORT_DISABLE_ALL); + Ort::Session fused_session = CreateSessionFromModelData(fused_model_data, &provider_options, GraphOptimizationLevel::ORT_DISABLE_ALL); + + auto unfused_outputs = unfused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, output_count); + auto fused_outputs = fused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, output_count); + + for (size_t output_index = 0; output_index < output_count; ++output_index) { + const size_t element_count = unfused_outputs[output_index].GetTensorTypeAndShapeInfo().GetElementCount(); + const auto* unfused_data = unfused_outputs[output_index].GetTensorData(); + const auto* fused_data = fused_outputs[output_index].GetTensorData(); + for (size_t i = 0; i < element_count; ++i) { + const float unfused_value = unfused_data[i].ToFloat(); + const float fused_value = fused_data[i].ToFloat(); + const float abs_diff = std::abs(unfused_value - fused_value); + const float allowed_diff = kDecodeCorrectnessAbsTolerance + + kDecodeCorrectnessRelTolerance * std::max(std::abs(unfused_value), std::abs(fused_value)); + if (abs_diff > allowed_diff) { + std::ostringstream stream; + stream << "QKV decode correctness check failed on output " << output_index + << " at index " << i + << ": unfused=" << unfused_value + << ", fused=" << fused_value + << ", abs_diff=" << abs_diff + << ", allowed_diff=" << allowed_diff; + throw std::runtime_error(stream.str()); + } + } + } + + std::cout << "QKV decode correctness check passed." << std::endl; +} + +[[maybe_unused]] static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { try { const DecodeBenchConfig config{ state.range(0), @@ -1088,6 +1409,112 @@ static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { } } +void BenchmarkWebGpuMatMulNBitsQkvDecode(benchmark::State& state, QkvDecodeBenchmarkVariant variant, QkvNormKind norm_kind) { + try { + const QkvDecodeBenchConfig config{ + state.range(0), + state.range(1), + state.range(2), + state.range(3), + state.range(4), + state.range(5), + }; + + if (config.k % config.block_size != 0) { + state.SkipWithError("K must be divisible by block_size for this benchmark skeleton."); + return; + } + + const QkvTrafficStats traffic = CalculateQkvTrafficStats(config, variant, norm_kind); + std::vector model_data = SerializeMatMulNBitsQkvModel(config, variant, norm_kind); + const SelectedWebGpuContext& selected_context = GetSelectedWebGpuContext(); + Ort::Session session = CreateSessionFromModelData(model_data, + &selected_context.provider_options, + GraphOptimizationLevel::ORT_DISABLE_ALL); + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + std::vector input_shape{1, config.k}; + std::vector activation(static_cast(config.k)); + std::vector skip_activation(static_cast(config.k)); + + std::mt19937 rng(123); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for (auto& value : activation) { + value = Ort::Float16_t(dist(rng)); + } + for (auto& value : skip_activation) { + value = Ort::Float16_t(dist(rng)); + } + + const char* simplified_input_names[] = {"A"}; + const char* skip_input_names[] = {"A", "Skip"}; + const char* simplified_output_names[] = {"Q", "K", "V"}; + const char* skip_output_names[] = {"Q", "K", "V", "SkipSum"}; + const char* const* input_names = norm_kind == QkvNormKind::kSkipSimplified ? skip_input_names : simplified_input_names; + const char* const* output_names = norm_kind == QkvNormKind::kSkipSimplified ? skip_output_names : simplified_output_names; + const size_t input_count = norm_kind == QkvNormKind::kSkipSimplified ? 2u : 1u; + const size_t output_count = norm_kind == QkvNormKind::kSkipSimplified ? 4u : 3u; + + std::array input_tensors = { + Ort::Value::CreateTensor(memory_info, + activation.data(), + activation.size(), + input_shape.data(), + input_shape.size()), + Ort::Value::CreateTensor(memory_info, + skip_activation.data(), + skip_activation.size(), + input_shape.data(), + input_shape.size())}; + Ort::RunOptions run_options = CreateBenchmarkRunOptions(); + + if (!IsDecodeBenchmarkPerfMode() && variant == QkvDecodeBenchmarkVariant::kFused) { + ValidateQkvDecodeOutputs(SerializeMatMulNBitsQkvModel(config, QkvDecodeBenchmarkVariant::kUnfused, norm_kind), + model_data, + selected_context.provider_options, + input_names, + input_tensors.data(), + input_count, + output_names, + output_count); + } + + for (int i = 0; i < kDecodeWarmupRuns; ++i) { + auto warmup_outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, output_count); + benchmark::DoNotOptimize(warmup_outputs); + } + + double total_kernel_seconds = 0.0; + for (auto _ : state) { + const auto kernel_start = std::chrono::steady_clock::now(); + auto outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, output_count); + const auto kernel_end = std::chrono::steady_clock::now(); + total_kernel_seconds += std::chrono::duration(kernel_end - kernel_start).count(); + benchmark::DoNotOptimize(outputs); + } + + const double total_flops = 2.0 * static_cast(config.k) * static_cast(config.q_n + 2 * config.kv_n); + const double achieved_bandwidth_bytes_per_second = + total_kernel_seconds > 0.0 + ? traffic.total_bytes * static_cast(state.iterations()) / total_kernel_seconds + : 0.0; + + state.SetLabel(GetQkvDecodeBenchmarkLabel(variant, norm_kind)); + state.counters["TFLOPS"] = benchmark::Counter(total_flops, benchmark::Counter::kIsIterationInvariantRate); + state.counters["ApproxMemBW_GBps"] = benchmark::Counter(achieved_bandwidth_bytes_per_second / 1.0e9); + state.counters["ApproxTraffic_MB"] = benchmark::Counter(traffic.total_bytes / 1.0e6); + state.counters["Input_MB"] = benchmark::Counter(traffic.input_bytes / 1.0e6); + state.counters["SkipInput_MB"] = benchmark::Counter(traffic.skip_input_bytes / 1.0e6); + state.counters["NormScale_MB"] = benchmark::Counter(traffic.norm_scale_bytes / 1.0e6); + state.counters["PackedW_MB"] = benchmark::Counter(traffic.packed_weight_bytes / 1.0e6); + state.counters["Scales_MB"] = benchmark::Counter(traffic.scale_bytes / 1.0e6); + state.counters["Intermediate_MB"] = benchmark::Counter(traffic.intermediate_bytes / 1.0e6); + state.counters["Output_MB"] = benchmark::Counter(traffic.output_bytes / 1.0e6); + state.counters["GraphReplay"] = benchmark::Counter(IsGraphCaptureBenchmarkEnabled() ? 1.0 : 0.0); + } catch (const std::exception& ex) { + state.SkipWithError(ex.what()); + } +} + void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, MlpDecodeBenchmarkVariant variant) { try { const MlpDecodeBenchConfig config{ @@ -1186,7 +1613,23 @@ static void BM_WebGpuMatMulNBitsMlpDecodeFused(benchmark::State& state) { BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused); } -void ApplyWebGpuMatMulNBitsDecodeArgs(benchmark::internal::Benchmark* benchmark) { +static void BM_WebGpuMatMulNBitsQkvDecodeUnfused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kUnfused, QkvNormKind::kSimplified); +} + +static void BM_WebGpuMatMulNBitsQkvDecodeFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kFused, QkvNormKind::kSimplified); +} + +static void BM_WebGpuMatMulNBitsQkvSkipDecodeUnfused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kUnfused, QkvNormKind::kSkipSimplified); +} + +static void BM_WebGpuMatMulNBitsQkvSkipDecodeFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kFused, QkvNormKind::kSkipSimplified); +} + +[[maybe_unused]] void ApplyWebGpuMatMulNBitsDecodeArgs(benchmark::internal::Benchmark* benchmark) { for (const auto& config : GetDecodeBenchConfigs()) { benchmark->Args({config.n, config.k, config.bits, config.block_size, config.accuracy_level}); } @@ -1198,12 +1641,18 @@ void ApplyWebGpuMatMulNBitsMlpDecodeArgs(benchmark::internal::Benchmark* benchma } } -BENCHMARK(BM_WebGpuMatMulNBitsDecode) - ->Apply(ApplyWebGpuMatMulNBitsDecodeArgs) - ->Repetitions(5) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); +void ApplyWebGpuMatMulNBitsQkvDecodeArgs(benchmark::internal::Benchmark* benchmark) { + for (const auto& config : GetQkvDecodeBenchConfigs()) { + benchmark->Args({config.q_n, config.kv_n, config.k, config.bits, config.block_size, config.accuracy_level}); + } +} + +// BENCHMARK(BM_WebGpuMatMulNBitsDecode) +// ->Apply(ApplyWebGpuMatMulNBitsDecodeArgs) +// ->Repetitions(5) +// ->ReportAggregatesOnly() +// ->UseRealTime() +// ->Unit(benchmark::TimeUnit::kMicrosecond); BENCHMARK(BM_WebGpuMatMulNBitsMlpDecodeUnfused) ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) @@ -1219,4 +1668,32 @@ BENCHMARK(BM_WebGpuMatMulNBitsMlpDecodeFused) ->UseRealTime() ->Unit(benchmark::TimeUnit::kMicrosecond); +BENCHMARK(BM_WebGpuMatMulNBitsQkvDecodeUnfused) + ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) + ->Repetitions(5) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsQkvDecodeFused) + ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) + ->Repetitions(5) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipDecodeUnfused) + ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) + ->Repetitions(5) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipDecodeFused) + ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) + ->Repetitions(5) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + } // namespace diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index ffa79abb4dc09..75b4c57c670a5 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -20,6 +20,8 @@ namespace test { namespace { constexpr const char* kOrtEnableMatMulNBitsSiluFusionEnvVar = "ORT_ENABLE_MATMUL_NBITS_SILU_FUSION"; +constexpr const char* kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar = + "ORT_ENABLE_MATMUL_NBITS_QKV_SIMPLIFIED_LAYER_NORM_FUSION"; bool HasTransformerNamed(const InlinedVector>& transformers, std::string_view name) { @@ -89,7 +91,8 @@ TEST(GraphTransformerUtilsTests, MatMulNBitsSiluFusionDisabledByDefault) { #if defined(DISABLE_CONTRIB_OPS) GTEST_SKIP() << "MatMulNBitsSiluFusion requires contrib ops."; #else - ScopedEnvironmentVariables scoped_env_vars{{kOrtEnableMatMulNBitsSiluFusionEnvVar, {}}}; + const EnvVarMap env_vars{{kOrtEnableMatMulNBitsSiluFusionEnvVar, optional{}}}; + ScopedEnvironmentVariables scoped_env_vars{env_vars}; CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); const auto& logger = DefaultLoggingManager().DefaultLogger(); @@ -103,7 +106,8 @@ TEST(GraphTransformerUtilsTests, MatMulNBitsSiluFusionEnabledViaEnvironmentVaria #if defined(DISABLE_CONTRIB_OPS) GTEST_SKIP() << "MatMulNBitsSiluFusion requires contrib ops."; #else - ScopedEnvironmentVariables scoped_env_vars{{kOrtEnableMatMulNBitsSiluFusionEnvVar, std::string{"1"}}}; + const EnvVarMap env_vars{{kOrtEnableMatMulNBitsSiluFusionEnvVar, optional{"1"}}}; + ScopedEnvironmentVariables scoped_env_vars{env_vars}; CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); const auto& logger = DefaultLoggingManager().DefaultLogger(); @@ -113,6 +117,36 @@ TEST(GraphTransformerUtilsTests, MatMulNBitsSiluFusionEnabledViaEnvironmentVaria #endif } +TEST(GraphTransformerUtilsTests, MatMulNBitsQKVSimplifiedLayerNormFusionDisabledByDefault) { +#if defined(DISABLE_CONTRIB_OPS) + GTEST_SKIP() << "MatMulNBitsQKVSimplifiedLayerNormFusion requires contrib ops."; +#else + const EnvVarMap env_vars{{kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar, optional{}}}; + ScopedEnvironmentVariables scoped_env_vars{env_vars}; + + CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + const auto& logger = DefaultLoggingManager().DefaultLogger(); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); + + EXPECT_FALSE(HasTransformerNamed(transformers, "MatMulNBitsQKVSimplifiedLayerNormFusion")); +#endif +} + +TEST(GraphTransformerUtilsTests, MatMulNBitsQKVSimplifiedLayerNormFusionEnabledViaEnvironmentVariable) { +#if defined(DISABLE_CONTRIB_OPS) + GTEST_SKIP() << "MatMulNBitsQKVSimplifiedLayerNormFusion requires contrib ops."; +#else + const EnvVarMap env_vars{{kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar, optional{"1"}}}; + ScopedEnvironmentVariables scoped_env_vars{env_vars}; + + CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + const auto& logger = DefaultLoggingManager().DefaultLogger(); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); + + EXPECT_TRUE(HasTransformerNamed(transformers, "MatMulNBitsQKVSimplifiedLayerNormFusion")); +#endif +} + TEST(GraphTransformerUtilsTests, TestDQMatMulNBitsFusionConfigWithContribGating) { SessionOptions session_options; const auto status = session_options.config_options.AddConfigEntry( diff --git a/onnxruntime/test/optimizer/matmul_nbits_qkv_sln_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_qkv_sln_fusion_test.cc new file mode 100644 index 0000000000000..b6e936ab883ad --- /dev/null +++ b/onnxruntime/test/optimizer/matmul_nbits_qkv_sln_fusion_test.cc @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/graph_transformer_mgr.h" +#include "core/optimizer/matmul_nbits_qkv_sln_fusion.h" +#include "core/optimizer/utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +#include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#include "test/optimizer/graph_transform_test_fixture.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +namespace { + +void SetWebGpuProvider(Node& node) { + node.SetExecutionProviderType(kWebGpuExecutionProvider); +} + +NodeAttributes MakeMatMulNBitsAttrs(int64_t k, int64_t n, int64_t block_size, int64_t bits, int64_t accuracy_level) { + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", k), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", n), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), attrs); + return attrs; +} + +Status CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraphImpl(const Graph& graph, bool expect_skip_sln_output) { + const auto op_to_count = CountOpsInGraph(graph); + if (OpCount(op_to_count, "com.microsoft.MatMulNBitsQKVSimplifiedLayerNorm") != 1 || + OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "com.microsoft.SkipSimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "com.microsoft.MatMulNBits") != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unexpected operator counts after MatMulNBitsQKVSimplifiedLayerNormFusion."); + } + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBitsQKVSimplifiedLayerNorm") { + ORT_RETURN_IF_NOT(node.Domain() == kMSDomain, "Fused node must be in com.microsoft domain."); + ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kWebGpuExecutionProvider, + "Fused node must be assigned to WebGPU EP."); + ORT_RETURN_IF_NOT(node.InputDefs().size() == 9, "Fused node must expose the 9-input contract."); + ORT_RETURN_IF_NOT(node.OutputDefs().size() == (expect_skip_sln_output ? 4u : 3u), + "Fused node outputs did not match the expected simplified vs skip-simplified contract."); + } + } + + return Status::OK(); +} + +Status CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraph(Graph& graph) { + return CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraphImpl(static_cast(graph), false); +} + +Status CheckMatMulNBitsQKVSimplifiedLayerNormSkipFusedGraph(Graph& graph) { + return CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraphImpl(static_cast(graph), true); +} + +void BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPatternImpl(ModelTestBuilder& builder, bool with_skip_input) { + constexpr int64_t k = 16; + constexpr int64_t q_n = 8; + constexpr int64_t kv_n = 4; + constexpr int64_t block_size = 32; + constexpr int64_t bits = 4; + constexpr int64_t accuracy_level = 4; + constexpr int64_t blob_size = block_size * bits / 8; + + NodeArg* input = builder.MakeInput( + std::vector{1, k}, + std::vector{ + MLFloat16(-1.0f), MLFloat16(-0.875f), MLFloat16(-0.75f), MLFloat16(-0.625f), + MLFloat16(-0.5f), MLFloat16(-0.375f), MLFloat16(-0.25f), MLFloat16(-0.125f), + MLFloat16(0.125f), MLFloat16(0.25f), MLFloat16(0.375f), MLFloat16(0.5f), + MLFloat16(0.625f), MLFloat16(0.75f), MLFloat16(0.875f), MLFloat16(1.0f)}); + NodeArg* skip_input = with_skip_input + ? builder.MakeInput( + std::vector{1, k}, + std::vector{ + MLFloat16(1.0f), MLFloat16(0.875f), MLFloat16(0.75f), MLFloat16(0.625f), + MLFloat16(0.5f), MLFloat16(0.375f), MLFloat16(0.25f), MLFloat16(0.125f), + MLFloat16(-0.125f), MLFloat16(-0.25f), MLFloat16(-0.375f), MLFloat16(-0.5f), + MLFloat16(-0.625f), MLFloat16(-0.75f), MLFloat16(-0.875f), MLFloat16(-1.0f)}) + : nullptr; + + NodeArg* norm_scale = builder.MakeInitializer({k}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* q_weight = builder.MakeInitializer({q_n, 1, blob_size}, uint8_t{0}, uint8_t{15}); + NodeArg* q_scale = builder.MakeInitializer({q_n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* k_weight = builder.MakeInitializer({kv_n, 1, blob_size}, uint8_t{0}, uint8_t{15}); + NodeArg* k_scale = builder.MakeInitializer({kv_n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* v_weight = builder.MakeInitializer({kv_n, 1, blob_size}, uint8_t{0}, uint8_t{15}); + NodeArg* v_scale = builder.MakeInitializer({kv_n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* optional_tensor = builder.MakeOptionalTensor(); + + NodeArg* norm_out = builder.MakeIntermediate(std::vector{1, k}); + NodeArg* optional_norm_output_1 = builder.MakeOptionalTensor(); + NodeArg* optional_norm_output_2 = builder.MakeOptionalTensor(); + NodeArg* residual_out = with_skip_input ? builder.MakeIntermediate(std::vector{1, k}) : nullptr; + NodeArg* q_output = builder.MakeOutput(std::vector{1, q_n}); + NodeArg* k_output = builder.MakeOutput(std::vector{1, kv_n}); + NodeArg* v_output = builder.MakeOutput(std::vector{1, kv_n}); + NodeArg* residual_passthrough = with_skip_input ? builder.MakeOutput(std::vector{1, k}) : nullptr; + + NodeAttributes q_attrs = MakeMatMulNBitsAttrs(k, q_n, block_size, bits, accuracy_level); + NodeAttributes kv_attrs = MakeMatMulNBitsAttrs(k, kv_n, block_size, bits, accuracy_level); + + Node& norm = with_skip_input + ? builder.AddNode("SkipSimplifiedLayerNormalization", {input, skip_input, norm_scale}, {norm_out, optional_norm_output_1, optional_norm_output_2, residual_out}, kMSDomain) + : builder.AddNode("SimplifiedLayerNormalization", {input, norm_scale}, {norm_out}); + norm.AddAttribute("epsilon", 1e-6f); + + Node& q_matmul = builder.AddNode("MatMulNBits", {norm_out, q_weight, q_scale, optional_tensor, optional_tensor, optional_tensor}, {q_output}, kMSDomain, &q_attrs); + Node& k_matmul = builder.AddNode("MatMulNBits", {norm_out, k_weight, k_scale, optional_tensor, optional_tensor, optional_tensor}, {k_output}, kMSDomain, &kv_attrs); + Node& v_matmul = builder.AddNode("MatMulNBits", {norm_out, v_weight, v_scale, optional_tensor, optional_tensor, optional_tensor}, {v_output}, kMSDomain, &kv_attrs); + + SetWebGpuProvider(norm); + SetWebGpuProvider(q_matmul); + SetWebGpuProvider(k_matmul); + SetWebGpuProvider(v_matmul); + + if (with_skip_input) { + Node& residual_identity = builder.AddNode("Identity", {residual_out}, {residual_passthrough}); + SetWebGpuProvider(residual_identity); + } +} + +void BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPatternImpl(builder, false); +} + +void BuildMatMulNBitsQKVSimplifiedLayerNormSkipWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPatternImpl(builder, true); +} + +} // namespace + +TEST_F(GraphTransformationTests, MatMulNBitsQKVSimplifiedLayerNormFusionFusesWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsQKVSimplifiedLayerNormFusionMatchesUnfusedWebGpuResults) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraphImpl(session.GetGraph(), false)); + }; + + TransformerTester( + BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + {}, + {}, + std::move(webgpu_ep)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsQKVSimplifiedLayerNormFusionFusesSkipWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsQKVSimplifiedLayerNormSkipWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsQKVSimplifiedLayerNormSkipFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsQKVSimplifiedLayerNormFusionMatchesUnfusedSkipWebGpuResults) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto add_session_options = [](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableSpecifiedOptimizers, + "EliminateIdentity")); + }; + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraphImpl(session.GetGraph(), true)); + }; + + TransformerTester( + BuildMatMulNBitsQKVSimplifiedLayerNormSkipWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + add_session_options, + {}, + std::move(webgpu_ep)); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) + +} // namespace test +} // namespace onnxruntime From b67ae8116daa2448c179d08cb31ba9c2dead935d Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Wed, 22 Apr 2026 18:27:23 -0700 Subject: [PATCH 07/26] Skip + MatmulNBitsSilu fusion - works and good perf --- .../webgpu/quantization/matmul_nbits_silu.cc | 385 ++++++++++++++++-- .../matmul_nbits_silu_mul.wgsl.template | 96 ++++- .../core/graph/contrib_ops/contrib_defs.cc | 64 ++- .../core/optimizer/graph_transformer_utils.cc | 4 +- .../optimizer/matmul_nbits_silu_fusion.cc | 177 +++++++- .../webgpu_matmul_nbits_decode.cc | 289 +++++++++++-- .../matmul_nbits_silu_fusion_test.cc | 237 ++++++++++- 7 files changed, 1137 insertions(+), 115 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc index 86e8d0c0db964..9b4f456c27894 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc @@ -7,8 +7,10 @@ #include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" #include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" #include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" +#include "contrib_ops/webgpu/bert/skip_layer_norm.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/nn/layer_norm.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" #include "core/providers/webgpu/webgpu_utils.h" @@ -23,6 +25,7 @@ constexpr unsigned int kMinMForTileOptimization = 4; constexpr uint32_t kFusedDecodeFastPathBits = 4u; constexpr uint32_t kFusedDecodeFastPathBlockSize = 32u; +constexpr float kSkipSimplifiedLayerNormEpsilon = 1e-05f; bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, int64_t K_op, @@ -116,11 +119,129 @@ bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a); } +TensorShape GetOverrideShape(const TensorShape& shape, int components) { + return TensorShape{shape.Size() / components}; +} + +Status ApplySimplifiedLayerNorm(const Tensor* x, + const Tensor* scale, + float epsilon, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { + const auto& x_shape = x->Shape(); + if (x_shape.Size() == 0) { + return Status::OK(); + } + + const int64_t norm_size = x_shape[x_shape.NumDimensions() - 1]; + const uint32_t norm_count = onnxruntime::narrow(x_shape.Size() / norm_size); + const int components = GetMaxComponents(norm_size); + const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); + const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1; + + onnxruntime::webgpu::LayerNormProgram program{/*has_bias=*/false, + /*simplified=*/true, + /*has_mean_output=*/false, + /*has_inv_std_dev_output=*/false, + split_norm_dim}; + + program.CacheHint(components, true, split_norm_dim) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x_shape, components), components}, + {scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}}) + .AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}}) + .AddUniformVariables({{static_cast(components)}, + {norm_count}, + {static_cast(norm_size)}, + {norm_size_vectorized}, + {epsilon}}); + + if (split_norm_dim) { + const uint32_t workgroup_size_x = 128; + const uint32_t dispatch_size_x = onnxruntime::narrow(norm_size / (workgroup_size_x * components)); + program.SetDispatchGroupSize(dispatch_size_x, 1, 1) + .SetWorkgroupSize(workgroup_size_x); + } else { + program.SetDispatchGroupSize(norm_count); + } + + return context.RunProgram(program); +} + +Status ApplySkipSimplifiedLayerNorm(const Tensor* x, + const Tensor* skip, + const Tensor* scale, + float epsilon, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + Tensor* input_skip_bias_sum) { + const auto& x_shape = x->Shape(); + if (x_shape.Size() == 0) { + return Status::OK(); + } + + const uint32_t hidden_size = onnxruntime::narrow(x_shape[x_shape.NumDimensions() - 1]); + const int components = GetMaxComponents(hidden_size); + const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(x_shape.NumDimensions() - 1)); + const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1; + const uint32_t skip_size = onnxruntime::narrow(skip->Shape().Size()); + + SkipLayerNormProgram program{/*hasBeta=*/false, + /*hasBias=*/false, + epsilon, + hidden_size, + input_skip_bias_sum != nullptr, + /*simplified=*/true, + split_hidden_dim}; + program + .CacheHint(/*simplified=*/true, input_skip_bias_sum != nullptr, split_hidden_dim) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{scale, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{y, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * x_shape.Size() / hidden_size))) + .AddUniformVariables({{static_cast(components)}}) + .AddUniformVariables({{hidden_size}}) + .AddUniformVariables({{epsilon}}) + .AddUniformVariables({{skip_size}}); + + if (split_hidden_dim) { + const uint32_t workgroup_size_x = 128; + const uint32_t dispatch_size_x = (input_skip_bias_sum != nullptr ? 2u : 1u) * hidden_size / + (workgroup_size_x * components); + program.SetDispatchGroupSize(dispatch_size_x, 1, 1) + .SetWorkgroupSize(workgroup_size_x); + } + + if (input_skip_bias_sum != nullptr) { + program.AddOutputs({{input_skip_bias_sum, ProgramTensorMetadataDependency::None, components}}); + } + + return context.RunProgram(program); +} + +Status ApplyUnfusedSiluMul(const Tensor* a, + const Tensor* gate_b, + const Tensor* gate_scales, + const Tensor* gate_bias, + const Tensor* up_b, + const Tensor* up_scales, + const Tensor* up_bias, + int64_t K, + int64_t N, + int64_t block_size, + int64_t accuracy_level, + int64_t bits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y); + class MatMulNBitsSiluMulDecodeProgram final : public Program { public: MatMulNBitsSiluMulDecodeProgram(uint32_t tile_size, bool has_gate_bias, bool has_up_bias, + bool has_norm_input, + bool has_skip_input, + bool has_skip_output, bool single_scale_weights, uint32_t tile_size_k_vec, uint32_t k_unroll_tiles, @@ -130,6 +251,9 @@ class MatMulNBitsSiluMulDecodeProgram final : public ProgramShape(), b_shape, false, true)); + const auto output_shape = helper.OutputShape(); + + Tensor gate_output = context.CreateGPUTensor(a->DataType(), output_shape); + Tensor up_output = context.CreateGPUTensor(a->DataType(), output_shape); + + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, gate_b, gate_scales, nullptr, gate_bias, K, N, block_size, accuracy_level, bits, context, &gate_output)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, up_b, up_scales, nullptr, up_bias, K, N, block_size, accuracy_level, bits, context, &up_output)); + + const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); + const uint32_t vec_size = (data_size + 3u) / 4u; + MatMulNBitsSiluMulProgram program; + program + .AddInputs({{&gate_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}, + {&up_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}}) + .AddOutput({y, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({vec_size}); + + return context.RunProgram(program); +} + } // namespace ONNX_OPERATOR_KERNEL_EX( @@ -287,12 +523,17 @@ ONNX_OPERATOR_KERNEL_EX( Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); - const Tensor* gate_b = context.Input(1); - const Tensor* gate_scales = context.Input(2); - const Tensor* gate_bias = context.Input(3); - const Tensor* up_b = context.Input(4); - const Tensor* up_scales = context.Input(5); - const Tensor* up_bias = context.Input(6); + const Tensor* skip = context.Input(1); + const Tensor* norm_scale = context.Input(2); + const Tensor* gate_b = context.Input(3); + const Tensor* gate_scales = context.Input(4); + const Tensor* gate_bias = context.Input(5); + const Tensor* up_b = context.Input(6); + const Tensor* up_scales = context.Input(7); + const Tensor* up_bias = context.Input(8); + + ORT_ENFORCE(skip == nullptr || norm_scale != nullptr, + "MatMulNBitsSiluMul requires norm_scale when skip is present."); MatMulComputeHelper helper; TensorShape b_shape({N_, K_}); @@ -315,11 +556,22 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; Tensor* y = context.Output(0, output_shape); + Tensor* input_skip_bias_sum = skip != nullptr ? context.Output(1, a->Shape()) : nullptr; const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); if (data_size == 0) { return Status::OK(); } + if (norm_scale != nullptr) { + ORT_ENFORCE(norm_scale->Shape().Size() == K_, "norm_scale must have shape [K]."); + } + + const bool is_decode_fast_path_candidate = + M == 1 && + bits_ == kFusedDecodeFastPathBits && + block_size == kFusedDecodeFastPathBlockSize; + const bool has_norm_input = norm_scale != nullptr; + const bool would_use_subgroup_unfused = WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(a, K_, @@ -347,11 +599,13 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& // The experimental wide M==1 fused path regressed badly on NVIDIA decode shapes. // Keep the implementation around for future work, but do not dispatch to it. - if (!would_use_subgroup_unfused && + const bool can_use_decode_fast_path = + is_decode_fast_path_candidate && + !would_use_subgroup_unfused && !would_use_dp4a_unfused && - !would_use_wide_tile_unfused && - M == 1 && bits_ == kFusedDecodeFastPathBits && - block_size == kFusedDecodeFastPathBlockSize) { + !would_use_wide_tile_unfused; + + if (can_use_decode_fast_path) { //ORT_ENFORCE(false, "The experimental wide M==1 fused path regressed badly on NVIDIA decode shapes. Keep the implementation around for future work, but do not dispatch to it."); ORT_ENFORCE(bits_ == kFusedDecodeFastPathBits, "MatMulNBitsSiluMulDecodeProgram is specialized for 4-bit weights only."); @@ -360,6 +614,8 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& const bool has_gate_bias = gate_bias != nullptr; const bool has_up_bias = up_bias != nullptr; + const bool has_skip_input = skip != nullptr; + const bool has_skip_output = input_skip_bias_sum != nullptr; uint32_t workgroup_size = 128; uint32_t tile_size = 8; uint32_t tile_size_k_vec = @@ -393,6 +649,9 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& MatMulNBitsSiluMulDecodeProgram program{tile_size, has_gate_bias, has_up_bias, + has_norm_input, + has_skip_input, + has_skip_output, single_scale_weights, tile_size_k_vec, k_unroll_tiles, @@ -400,9 +659,15 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& has_full_k_tiles}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize(num_N_tile, 1, batch_count); + program.AddInput({a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + if (has_skip_input) { + program.AddInput({skip, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + } + if (has_norm_input) { + program.AddInput({norm_scale, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + } program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, - {gate_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + .AddInputs({{gate_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, {gate_scales, ProgramTensorMetadataDependency::TypeAndRank}, {up_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, {up_scales, ProgramTensorMetadataDependency::TypeAndRank}}) @@ -414,15 +679,25 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& {block_size}, {n_blocks_per_col}, {num_N_tile}, - {batch_count}}) + {batch_count}, + {has_skip_input ? onnxruntime::narrow(skip->Shape().Size()) : 0u}, + {kSkipSimplifiedLayerNormEpsilon}}) .CacheHint(single_scale_weights, - has_gate_bias, - has_up_bias, - tile_size_k_vec, - k_unroll_tiles, - has_full_n_tiles, - has_full_k_tiles, - "decode_4bit"); + has_gate_bias, + has_up_bias, + has_norm_input, + has_skip_input, + has_skip_output, + tile_size_k_vec, + k_unroll_tiles, + has_full_n_tiles, + has_full_k_tiles, + "decode_4bit"); + if (has_skip_output) { + program.AddOutput({input_skip_bias_sum, + ProgramTensorMetadataDependency::TypeAndRank, + static_cast(components_a)}); + } if (has_gate_bias) { program.AddInput({gate_bias, ProgramTensorMetadataDependency::None}); } @@ -433,23 +708,59 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& return context.RunProgram(program); } - Tensor gate_output = context.CreateGPUTensor(a->DataType(), output_shape); - Tensor up_output = context.CreateGPUTensor(a->DataType(), output_shape); - - //ORT_ENFORCE(false, "Reached prefill."); - ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, gate_b, gate_scales, nullptr, gate_bias, K_, N_, block_size_, accuracy_level_, bits_, context, &gate_output)); - ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, up_b, up_scales, nullptr, up_bias, K_, N_, block_size_, accuracy_level_, bits_, context, &up_output)); + if (skip != nullptr) { + Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); + ORT_RETURN_IF_ERROR(ApplySkipSimplifiedLayerNorm(a, skip, norm_scale, kSkipSimplifiedLayerNormEpsilon, + context, &normalized_a, input_skip_bias_sum)); + return ApplyUnfusedSiluMul(&normalized_a, + gate_b, + gate_scales, + gate_bias, + up_b, + up_scales, + up_bias, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + context, + y); + } - const uint32_t vec_size = (data_size + 3u) / 4u; - MatMulNBitsSiluMulProgram program; - program - .AddInputs({{&gate_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}, - {&up_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}}) - .AddOutput({y, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) - .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({vec_size}); + if (norm_scale != nullptr) { + Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); + ORT_RETURN_IF_ERROR(ApplySimplifiedLayerNorm(a, norm_scale, kSkipSimplifiedLayerNormEpsilon, context, &normalized_a)); + return ApplyUnfusedSiluMul(&normalized_a, + gate_b, + gate_scales, + gate_bias, + up_b, + up_scales, + up_bias, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + context, + y); + } - return context.RunProgram(program); + return ApplyUnfusedSiluMul(a, + gate_b, + gate_scales, + gate_bias, + up_b, + up_scales, + up_bias, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + context, + y); } } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template index 25e3b869d0c21..190d06ba958e0 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template @@ -9,6 +9,9 @@ #param has_full_n_tiles #param single_scale_weights #param sub_tile_count +#param has_norm_input +#param has_skip_input +#param has_skip_output #param tile_size_k_vec #param tile_size_k #param tile_size @@ -18,20 +21,53 @@ #use .getByOffset .setByOffset +#if has_norm_input +var sum_squared_shared : array; +#endif var tile_A : array; var gate_inter_results : array, tile_size>; var up_inter_results : array, tile_size>; const default_zero_point = output_element_t(8); -fn loadSHMA(batch: u32, kidx: u32, col: u32) +fn load_merged_input(input_offset: u32) -> input_a_value_t { +#if has_skip_input + let skip_offset = input_offset % (uniforms.skip_size / component_a); + return a.getByOffset(input_offset) + input_a_value_t(skip.getByOffset(skip_offset)); +#else + return a.getByOffset(input_offset); +#endif +} + +fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) { let k_offset = kidx / component_a + col; + let input_offset = batch * uniforms.K_of_a + k_offset; #if has_full_k_tiles - tile_A[col] = a.getByOffset(batch * uniforms.K_of_a + k_offset); + let merged_value = load_merged_input(input_offset); +#if has_skip_output + if (b_global_base == 0u) { + input_skip_bias_sum.setByOffset(input_offset, input_skip_bias_sum_value_t(merged_value)); + } +#endif +#if has_norm_input + tile_A[col] = merged_value * input_a_value_t(input_a_element_t(inv_std)) * norm_scale.getByOffset(k_offset); +#else + tile_A[col] = merged_value; +#endif #else if (k_offset < uniforms.K_of_a) { - tile_A[col] = a.getByOffset(batch * uniforms.K_of_a + k_offset); + let merged_value = load_merged_input(input_offset); +#if has_skip_output + if (b_global_base == 0u) { + input_skip_bias_sum.setByOffset(input_offset, input_skip_bias_sum_value_t(merged_value)); + } +#endif +#if has_norm_input + tile_A[col] = merged_value * input_a_value_t(input_a_element_t(inv_std)) * norm_scale.getByOffset(k_offset); +#else + tile_A[col] = merged_value; +#endif } else { tile_A[col] = input_a_value_t(0); } @@ -110,10 +146,10 @@ fn compute_gate_up_sums(b_global: u32, kidx: u32, idx: u32, k_offset: u32) -> ve return vec2(gate_sum, up_sum); } -fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy: u32, kidx: u32) { +fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy: u32, kidx: u32, inv_std: f32) { for (var id = thread_idx; id < a_length_per_tile; id += workgroup_size_x) { - loadSHMA(batch, kidx, id); + loadSHMA(batch, b_global_base, kidx, id, inv_std); } workgroupBarrier(); @@ -165,6 +201,38 @@ $MAIN { } workgroupBarrier(); +#if has_norm_input + var sum_squared_local = 0.0; + for (var a_idx = local_idx; a_idx < uniforms.K_of_a; a_idx += workgroup_size_x) { + let a_value = load_merged_input(batch * uniforms.K_of_a + a_idx); +#if component_a == 1 + let a_f32 = f32(a_value); + sum_squared_local += a_f32 * a_f32; +#elif component_a == 2 + let a_f32 = vec2(a_value); + sum_squared_local += dot(a_f32, a_f32); +#elif component_a == 4 + let a_f32 = vec4(a_value); + sum_squared_local += dot(a_f32, a_f32); +#endif + } + sum_squared_shared[local_idx] = sum_squared_local; + workgroupBarrier(); + + var reduce_size : u32 = workgroup_size_x; + for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) { + reduce_size = curr_size + (reduce_size & 1u); + if (local_idx < curr_size) { + sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size]; + } + workgroupBarrier(); + } + + let inv_std = inverseSqrt(sum_squared_shared[0] / f32(uniforms.K) + uniforms.epsilon); +#else + let inv_std = 1.0; +#endif + #if single_scale_weights let gate_scale_b = gate_scales_b.getByOffset(0); let up_scale_b = up_scales_b.getByOffset(0); @@ -173,29 +241,29 @@ $MAIN { #if k_unroll_tiles == 1 for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) { - process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); } #elif k_unroll_tiles == 2 let unrolled_k_step = tile_size_k * 2u; let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step); for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) { - process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx); - process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std); } for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) { - process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); } #elif k_unroll_tiles == 4 let unrolled_k_step = tile_size_k * 4u; let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step); for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) { - process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx); - process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k); - process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 2u); - process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 3u); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 2u, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 3u, inv_std); } for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) { - process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); } #endif diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index a1aba7b9475f1..e6a9ed43460c8 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3623,9 +3623,23 @@ MatMulNBitsSiluMul fuses two MatMulNBits projections that share the same input a where SiLU(x) = x * sigmoid(x). +It can also optionally fuse SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization before the +two projections: + + A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon) + Y = SiLU(MatMulNBits(A_norm, gate_weight) + gate_bias) * (MatMulNBits(A_norm, up_weight) + up_bias) + + A_norm = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon) + Y = SiLU(MatMulNBits(A_norm, gate_weight) + gate_bias) * (MatMulNBits(A_norm, up_weight) + up_bias) + This operator is intended for decoder MLP patterns such as Qwen-style gate and up projections, but it remains semantically valid for both prefill and decode because the output shape is the standard MatMul result shape derived from the runtime shape of A and the shared attributes K and N. + +When fused from SkipSimplifiedLayerNormalization, the optional residual-sum output may also be materialized: + + A_norm, input_skip_bias_sum = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon) + Y = SiLU(MatMulNBits(A_norm, gate_weight) + gate_bias) * (MatMulNBits(A_norm, up_weight) + up_bias) )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBitsSiluMul) @@ -3642,24 +3656,60 @@ derived from the runtime shape of A and the shared attributes K and N. "The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.", AttributeProto::INT, static_cast(0)) .Input(0, "A", "The shared input tensor.", "T1") - .Input(1, "gate_B", "Packed uint8 tensor for the gate projection weights.", "T2") - .Input(2, "gate_scales", "Per-block scaling factors for the gate projection.", "T1") - .Input(3, "gate_bias", "Optional bias for the gate projection with shape [N].", "T1", OpSchema::Optional) - .Input(4, "up_B", "Packed uint8 tensor for the up projection weights.", "T2") - .Input(5, "up_scales", "Per-block scaling factors for the up projection.", "T1") - .Input(6, "up_bias", "Optional bias for the up projection with shape [N].", "T1", OpSchema::Optional) + .Input(1, "skip", "Optional skip input used by SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) + .Input(2, "norm_scale", "Optional RMSNorm scale with shape [K] used by SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) + .Input(3, "gate_B", "Packed uint8 tensor for the gate projection weights.", "T2") + .Input(4, "gate_scales", "Per-block scaling factors for the gate projection.", "T1") + .Input(5, "gate_bias", "Optional bias for the gate projection with shape [N].", "T1", OpSchema::Optional) + .Input(6, "up_B", "Packed uint8 tensor for the up projection weights.", "T2") + .Input(7, "up_scales", "Per-block scaling factors for the up projection.", "T1") + .Input(8, "up_bias", "Optional bias for the up projection with shape [N].", "T1", OpSchema::Optional) .Output(0, "Y", "The fused SiLU-multiply output tensor.", "T1") + .Output(1, "input_skip_bias_sum", "Optional residual-sum output for SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateElemTypeFromInputToOutput(ctx, 0, 1); + } const int64_t in_features = getAttribute(ctx, "K", -1); const int64_t out_features = getAttribute(ctx, "N", -1); MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true); - for (size_t bias_input_index : {3U, 6U}) { + if (ctx.hasInput(1) && !ctx.hasInput(2)) { + fail_shape_inference("norm_scale input must be present when skip input is provided"); + } + + if (ctx.hasOutput(1)) { + if (!ctx.hasInput(1)) { + fail_shape_inference("skip input must be present when input_skip_bias_sum output is requested"); + } + + if (!hasInputShape(ctx, 0)) { + return; + } + + auto* skip_sum_shape = getOutputShape(ctx, 1); + *skip_sum_shape = getInputShape(ctx, 0); + } + + if (ctx.hasInput(2)) { + if (!hasInputShape(ctx, 2)) { + fail_shape_inference("norm_scale shape must be known"); + } + + const auto& norm_scale_shape = getInputShape(ctx, 2); + if (norm_scale_shape.dim_size() != 1 || + !norm_scale_shape.dim(0).has_dim_value() || + norm_scale_shape.dim(0).dim_value() != in_features) { + fail_shape_inference("norm_scale shape must be [K] where K = ", in_features); + } + } + + for (size_t bias_input_index : {5U, 8U}) { if (!ctx.hasInput(static_cast(bias_input_index))) { continue; } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 46c2dcec1543a..741781e2d0f18 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -110,12 +110,12 @@ constexpr const char* kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar = #if !defined(ORT_MINIMAL_BUILD) bool IsMatMulNBitsSiluFusionEnabled() { - return ParseEnvironmentVariableWithDefault(kOrtEnableMatMulNBitsSiluFusionEnvVar, 1) == 1; + return ParseEnvironmentVariableWithDefault(kOrtEnableMatMulNBitsSiluFusionEnvVar, 0) == 1; //return true; } bool IsMatMulNBitsQKVSimplifiedLayerNormFusionEnabled() { - return ParseEnvironmentVariableWithDefault(kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar, 1) == 1; + return ParseEnvironmentVariableWithDefault(kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar, 0) == 1; //return true; } #endif diff --git a/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.cc index f17aa28a3c79d..28b96a2bae7b6 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.cc @@ -3,6 +3,7 @@ #include "core/optimizer/matmul_nbits_silu_fusion.h" +#include #include #include "core/graph/graph_utils.h" @@ -30,6 +31,51 @@ bool IsSupportedSigmoid(const Node& node) { return graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}); } +bool IsSupportedSimplifiedLayerNormalization(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SimplifiedLayerNormalization", {1}); +} + +bool IsSupportedSkipSimplifiedLayerNormalization(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SkipSimplifiedLayerNormalization", {1}, kMSDomain); +} + +bool IsSupportedSiluNormAnchor(const Node& node) { + return IsSupportedSimplifiedLayerNormalization(node) || IsSupportedSkipSimplifiedLayerNormalization(node); +} + +bool HasProducedOutput(const Node& node, size_t index) { + return index < node.OutputDefs().size() && node.OutputDefs()[index] != nullptr && !node.OutputDefs()[index]->Name().empty(); +} + +bool ProducesOnlyOptionalSkipOutputAsGraphOutput(const Graph& graph, const Node& node) { + const auto graph_outputs = graph.GetNodeOutputsInGraphOutputs(node); + return std::all_of(graph_outputs.begin(), graph_outputs.end(), [](int output_idx) { return output_idx == 3; }); +} + +size_t ExpectedNormConsumerEdgeCount(const Node& node) { + return 2u + ((IsSupportedSkipSimplifiedLayerNormalization(node) && HasProducedOutput(node, 3)) ? 1u : 0u); +} + +bool HasExpectedNormConsumers(const Graph& graph, const Node& node) { + const auto graph_outputs = graph.GetNodeOutputsInGraphOutputs(node); + const size_t expected_output_edges = ExpectedNormConsumerEdgeCount(node) - graph_outputs.size(); + if (node.GetOutputEdgesCount() != expected_output_edges) { + return false; + } + + // Match optimizer_utils::CheckOutputEdges safety check while allowing output 3 to be a graph output. + for (auto output_edge_it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); output_edge_it != end; ++output_edge_it) { + const auto& output_node = output_edge_it->GetNode(); + const auto output_node_input_arg_idx = static_cast(output_edge_it->GetDstArgIndex()); + const bool is_implicit_input_to_output_node = output_node_input_arg_idx >= output_node.InputDefs().size(); + if (is_implicit_input_to_output_node) { + return false; + } + } + + return true; +} + bool IsMatMulNBitsWithoutZeroPointOrGroupIdx(const Node& node) { return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain) && !HasInput(node, 3) && !HasInput(node, 4); @@ -49,6 +95,40 @@ bool HasSingleNonGraphConsumer(const Graph& graph, const Node& node) { return !graph.NodeProducesGraphOutput(node) && optimizer_utils::CheckOutputEdges(graph, node, 1); } +const Node* GetOptionalNormProducer(const Graph& graph, + const Node& gate_matmul, + const Node& up_matmul) { + if (gate_matmul.InputDefs().empty() || up_matmul.InputDefs().empty() || + gate_matmul.InputDefs()[0] != up_matmul.InputDefs()[0]) { + return nullptr; + } + + const Node* gate_input = GetInputNode(graph, gate_matmul, 0); + const Node* up_input = GetInputNode(graph, up_matmul, 0); + if (gate_input == nullptr || gate_input != up_input || !IsSupportedSiluNormAnchor(*gate_input)) { + return nullptr; + } + + if (!HasProducedOutput(*gate_input, 0)) { + return nullptr; + } + + if (graph.NodeProducesGraphOutput(*gate_input) && !ProducesOnlyOptionalSkipOutputAsGraphOutput(graph, *gate_input)) { + return nullptr; + } + + if (!HasExpectedNormConsumers(graph, *gate_input)) { + return nullptr; + } + + const size_t min_norm_inputs = IsSupportedSkipSimplifiedLayerNormalization(*gate_input) ? 3u : 2u; + if (gate_input->InputDefs().size() < min_norm_inputs) { + return nullptr; + } + + return gate_input; +} + bool IsFuseCandidate(const Graph& graph, const Node& gate_matmul, const Node& up_matmul, @@ -196,6 +276,12 @@ Status MatMulNBitsSiluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ continue; } + const Node* norm = GetOptionalNormProducer(graph, *gate_matmul, *up_matmul); + if (norm != nullptr && + !norm->GetExecutionProviderType().empty() && norm->GetExecutionProviderType() != kWebGpuExecutionProvider) { + continue; + } + NodeAttributes attrs; utils::SetNodeAttribute(utils::MakeAttribute("K", GetIntAttr(*gate_matmul, "K", -1, true)), attrs); utils::SetNodeAttribute(utils::MakeAttribute("N", GetIntAttr(*gate_matmul, "N", -1, true)), attrs); @@ -204,9 +290,12 @@ Status MatMulNBitsSiluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", GetIntAttr(*gate_matmul, "accuracy_level", 0)), attrs); NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + const bool is_skip_sln = norm != nullptr && IsSupportedSkipSimplifiedLayerNormalization(*norm); InlinedVector fused_inputs{ - const_cast(gate_matmul->InputDefs()[0]), + const_cast(norm != nullptr ? norm->InputDefs()[0] : gate_matmul->InputDefs()[0]), + is_skip_sln ? const_cast(norm->InputDefs()[1]) : &empty_arg, + norm != nullptr ? const_cast(norm->InputDefs()[is_skip_sln ? 2 : 1]) : &empty_arg, const_cast(gate_matmul->InputDefs()[1]), const_cast(gate_matmul->InputDefs()[2]), HasInput(*gate_matmul, 5) ? const_cast(gate_matmul->InputDefs()[5]) : &empty_arg, @@ -215,25 +304,91 @@ Status MatMulNBitsSiluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ HasInput(*up_matmul, 5) ? const_cast(up_matmul->InputDefs()[5]) : &empty_arg, }; + InlinedVector fused_outputs{const_cast(node.OutputDefs()[0])}; + const bool preserve_skip_output = is_skip_sln && norm != nullptr && HasProducedOutput(*norm, 3); + if (preserve_skip_output) { + fused_outputs.push_back(const_cast(norm->OutputDefs()[3])); + } + + const auto norm_input_edges = norm != nullptr ? graph_utils::GraphEdge::GetNodeInputEdges(*norm) + : std::vector{}; + const auto gate_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*gate_matmul); + const auto up_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*up_matmul); + const auto final_mul_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node); + const auto norm_output_edges = preserve_skip_output ? graph_utils::GraphEdge::GetNodeOutputEdges(*norm) + : std::vector{}; + + if (norm != nullptr) { + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*norm)); + graph.RemoveNode(norm->Index()); + } + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*gate_matmul)); + graph.RemoveNode(gate_matmul->Index()); + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*up_matmul)); + graph.RemoveNode(up_matmul->Index()); + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*sigmoid)); + graph.RemoveNode(sigmoid->Index()); + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*silu_mul)); + graph.RemoveNode(silu_mul->Index()); + graph_utils::RemoveNodeOutputEdges(graph, node); + graph.RemoveNode(node.Index()); + Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsSiluMul"), "MatMulNBitsSiluMul", "fused MatMulNBits gate/up projections with SiLU multiply", fused_inputs, - {}, + fused_outputs, &attrs, kMSDomain); fused_node.SetExecutionProviderType(kWebGpuExecutionProvider); LOGS(logger, INFO) << "MatMulNBitsSiluFusion: created fused node '" << fused_node.Name() - << "' from final_mul='" << node.Name() << "'"; - - graph_utils::FinalizeNodeFusion(graph, - {std::ref(const_cast(*gate_matmul)), - std::ref(const_cast(*up_matmul)), - std::ref(const_cast(*sigmoid)), - std::ref(const_cast(*silu_mul)), - std::ref(node)}, - fused_node); + << "' from final_mul='" << node.Name() << "'"; + + if (norm != nullptr) { + for (const auto& input_edge : norm_input_edges) { + int fused_input_index = input_edge.dst_arg_index; + if (!is_skip_sln && input_edge.dst_arg_index == 1) { + fused_input_index = 2; + } + + graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index); + } + } else { + for (const auto& input_edge : gate_input_edges) { + if (input_edge.dst_arg_index == 0) { + graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, 0); + } + } + } + + auto add_input_edge_if_present = [&](const std::vector& edges, + int source_input_index, + int fused_input_index) { + for (const auto& input_edge : edges) { + if (input_edge.dst_arg_index == source_input_index) { + graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index); + } + } + }; + + add_input_edge_if_present(gate_input_edges, 1, 3); + add_input_edge_if_present(gate_input_edges, 2, 4); + add_input_edge_if_present(gate_input_edges, 5, 5); + add_input_edge_if_present(up_input_edges, 1, 6); + add_input_edge_if_present(up_input_edges, 2, 7); + add_input_edge_if_present(up_input_edges, 5, 8); + + for (const auto& output_edge : final_mul_output_edges) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 0, output_edge.dst_arg_index); + } + if (preserve_skip_output) { + for (const auto& output_edge : norm_output_edges) { + if (output_edge.src_arg_index == 3) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 1, output_edge.dst_arg_index); + } + } + } modified = true; } diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc index 193e01c4c6aa8..7bc5c19c191c5 100644 --- a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc +++ b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc @@ -63,6 +63,13 @@ struct DecodeBenchConfig { enum class MlpDecodeBenchmarkVariant { kUnfused, kFused, + kSkipNormThenFused, + kSkipNormPassthroughThenFused, +}; + +enum class MlpNormKind { + kNone, + kSkipSimplified, }; struct MlpDecodeBenchConfig { @@ -122,6 +129,8 @@ struct DecodeTrafficStats { struct MlpTrafficStats { double input_bytes; + double skip_input_bytes; + double norm_scale_bytes; double packed_weight_bytes; double scale_bytes; double intermediate_bytes; @@ -350,26 +359,44 @@ DecodeTrafficStats CalculateDecodeTrafficStats(const DecodeBenchConfig& config) }; } -MlpTrafficStats CalculateMlpTrafficStats(const MlpDecodeBenchConfig& config, MlpDecodeBenchmarkVariant variant) { +MlpTrafficStats CalculateMlpTrafficStats(const MlpDecodeBenchConfig& config, + MlpDecodeBenchmarkVariant variant, + MlpNormKind norm_kind) { const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; const int64_t blob_size = (config.block_size * config.bits) / 8; - const double input_reads = variant == MlpDecodeBenchmarkVariant::kFused ? 1.0 : 2.0; + const bool is_unfused = variant == MlpDecodeBenchmarkVariant::kUnfused; + const bool is_skip_norm_then_fused = variant == MlpDecodeBenchmarkVariant::kSkipNormThenFused; + const bool is_skip_norm_passthrough_then_fused = variant == MlpDecodeBenchmarkVariant::kSkipNormPassthroughThenFused; + const bool has_skip_norm = norm_kind == MlpNormKind::kSkipSimplified; + const double input_reads = variant == MlpDecodeBenchmarkVariant::kUnfused ? 2.0 : 1.0; const double intermediate_bytes = - variant == MlpDecodeBenchmarkVariant::kFused ? 0.0 : 4.0 * static_cast(config.n) * sizeof(Ort::Float16_t); + (is_unfused ? 4.0 * static_cast(config.n) * sizeof(Ort::Float16_t) : 0.0) + + ((is_unfused || is_skip_norm_then_fused) && has_skip_norm + ? static_cast(config.k) * sizeof(Ort::Float16_t) + : 0.0); const double input_bytes = input_reads * static_cast(config.k) * sizeof(Ort::Float16_t); + const double skip_input_bytes = + has_skip_norm ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; + const double norm_scale_bytes = + has_skip_norm ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; const double packed_weight_bytes = 2.0 * static_cast(config.n) * static_cast(k_blocks) * static_cast(blob_size); const double scale_bytes = 2.0 * static_cast(config.n) * static_cast(k_blocks) * sizeof(Ort::Float16_t); - const double output_bytes = static_cast(config.n) * sizeof(Ort::Float16_t); + const double output_bytes = + (static_cast(config.n) + + (is_skip_norm_passthrough_then_fused && has_skip_norm ? static_cast(config.k) : 0.0)) * + sizeof(Ort::Float16_t); return { input_bytes, + skip_input_bytes, + norm_scale_bytes, packed_weight_bytes, scale_bytes, intermediate_bytes, output_bytes, - input_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, + input_bytes + skip_input_bytes + norm_scale_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, }; } @@ -713,11 +740,14 @@ void AddMatMulNBitsNode(ONNX_NAMESPACE::GraphProto& graph, void AddMatMulNBitsSiluMulNode(ONNX_NAMESPACE::GraphProto& graph, const std::string& node_name, const std::string& input_name, + const std::string& skip_input_name, + const std::string& norm_scale_name, const std::string& gate_weight_name, const std::string& gate_scale_name, const std::string& up_weight_name, const std::string& up_scale_name, const std::string& output_name, + const std::string& skip_sum_output_name, int64_t k, int64_t n, int64_t bits, @@ -728,6 +758,8 @@ void AddMatMulNBitsSiluMulNode(ONNX_NAMESPACE::GraphProto& graph, node->set_op_type("MatMulNBitsSiluMul"); node->set_domain("com.microsoft"); node->add_input(input_name); + node->add_input(skip_input_name); + node->add_input(norm_scale_name); node->add_input(gate_weight_name); node->add_input(gate_scale_name); node->add_input(""); @@ -735,6 +767,9 @@ void AddMatMulNBitsSiluMulNode(ONNX_NAMESPACE::GraphProto& graph, node->add_input(up_scale_name); node->add_input(""); node->add_output(output_name); + if (!skip_sum_output_name.empty()) { + node->add_output(skip_sum_output_name); + } auto* attr_k = node->add_attribute(); attr_k->set_name("K"); @@ -888,12 +923,27 @@ std::vector SerializeMatMulNBitsModel(const DecodeBenchConfig& config) } std::string GetMlpVariantLabel(MlpDecodeBenchmarkVariant variant) { - return variant == MlpDecodeBenchmarkVariant::kFused ? "fused" : "unfused"; + switch (variant) { + case MlpDecodeBenchmarkVariant::kUnfused: + return "unfused"; + case MlpDecodeBenchmarkVariant::kFused: + return "fused"; + case MlpDecodeBenchmarkVariant::kSkipNormThenFused: + return "skip_norm_then_fused"; + case MlpDecodeBenchmarkVariant::kSkipNormPassthroughThenFused: + return "skip_norm_passthrough_then_fused"; + } + + return "unknown"; } -std::string GetMlpDecodeBenchmarkLabel(MlpDecodeBenchmarkVariant variant) { +std::string GetMlpNormKindLabel(MlpNormKind norm_kind) { + return norm_kind == MlpNormKind::kSkipSimplified ? "skip_simplified" : "plain"; +} + +std::string GetMlpDecodeBenchmarkLabel(MlpDecodeBenchmarkVariant variant, MlpNormKind norm_kind) { std::ostringstream stream; - stream << "fp16_mlp_decode_" << GetMlpVariantLabel(variant) << '_' + stream << "fp16_mlp_decode_" << GetMlpNormKindLabel(norm_kind) << '_' << GetMlpVariantLabel(variant) << '_' << (IsDecodeBenchmarkPerfMode() ? "perf" : "correctness") << '_' << (GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kRtx5060Ti ? "rtx" : "t") << '_' << (IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off") << '_' @@ -920,7 +970,8 @@ std::string GetQkvDecodeBenchmarkLabel(QkvDecodeBenchmarkVariant variant, QkvNor } std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& config, - MlpDecodeBenchmarkVariant variant) { + MlpDecodeBenchmarkVariant variant, + MlpNormKind norm_kind) { const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; const int64_t blob_size = (config.block_size * config.bits) / 8; @@ -935,8 +986,21 @@ std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& co ms_opset->set_version(1); auto* graph = model.mutable_graph(); - graph->set_name(variant == MlpDecodeBenchmarkVariant::kFused ? "WebGpuMatMulNBitsMlpDecodeFused" - : "WebGpuMatMulNBitsMlpDecodeUnfused"); + switch (variant) { + case MlpDecodeBenchmarkVariant::kFused: + graph->set_name("WebGpuMatMulNBitsMlpDecodeFused"); + break; + case MlpDecodeBenchmarkVariant::kSkipNormThenFused: + graph->set_name("WebGpuMatMulNBitsMlpSkipNormThenFused"); + break; + case MlpDecodeBenchmarkVariant::kSkipNormPassthroughThenFused: + graph->set_name("WebGpuMatMulNBitsMlpSkipNormPassthroughThenFused"); + break; + case MlpDecodeBenchmarkVariant::kUnfused: + default: + graph->set_name("WebGpuMatMulNBitsMlpDecodeUnfused"); + break; + } auto* input = graph->add_input(); input->set_name("A"); @@ -944,6 +1008,14 @@ std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& co input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); + if (norm_kind == MlpNormKind::kSkipSimplified) { + auto* skip_input = graph->add_input(); + skip_input->set_name("Skip"); + skip_input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + skip_input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + skip_input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); + } + auto* output = graph->add_output(); output->set_name("Y"); output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); @@ -954,6 +1026,7 @@ std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& co std::vector up_b(static_cast(config.n * k_blocks * blob_size), uint8_t{0x77}); std::vector gate_scales(static_cast(config.n * k_blocks), Ort::Float16_t(0.03125f)); std::vector up_scales(static_cast(config.n * k_blocks), Ort::Float16_t(0.0625f)); + std::vector norm_scale(static_cast(config.k), Ort::Float16_t(1.0f)); AddTensorInitializer(*graph, "gate_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, {config.n, k_blocks, blob_size}, gate_b); @@ -963,22 +1036,105 @@ std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& co {config.n, k_blocks}, gate_scales); AddTensorInitializer(*graph, "up_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.n, k_blocks}, up_scales); + if (norm_kind == MlpNormKind::kSkipSimplified) { + AddTensorInitializer(*graph, "norm_scale", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + {config.k}, norm_scale); + } if (variant == MlpDecodeBenchmarkVariant::kFused) { AddMatMulNBitsSiluMulNode(*graph, "MatMulNBitsSiluMulDecode", "A", + norm_kind == MlpNormKind::kSkipSimplified ? "Skip" : "", + norm_kind == MlpNormKind::kSkipSimplified ? "norm_scale" : "", + "gate_B", + "gate_scales", + "up_B", + "up_scales", + "Y", + "", + config.k, + config.n, + config.bits, + config.block_size, + config.accuracy_level); + } else if (variant == MlpDecodeBenchmarkVariant::kSkipNormThenFused) { + ORT_ENFORCE(norm_kind == MlpNormKind::kSkipSimplified, + "SkipNormThenFused benchmark variant requires SkipSimplified norm kind."); + AddTensorValueInfo(*graph, "A_norm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); + auto* norm = graph->add_node(); + norm->set_name("InputSkipSimplifiedLayerNorm"); + norm->set_op_type("SkipSimplifiedLayerNormalization"); + norm->set_domain("com.microsoft"); + norm->add_input("A"); + norm->add_input("Skip"); + norm->add_input("norm_scale"); + norm->add_output("A_norm"); + auto* attr_epsilon = norm->add_attribute(); + attr_epsilon->set_name("epsilon"); + attr_epsilon->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); + attr_epsilon->set_f(1e-6f); + + AddMatMulNBitsSiluMulNode(*graph, + "MatMulNBitsSiluMulDecodeAfterSkipNorm", + "A_norm", + "", + "", + "gate_B", + "gate_scales", + "up_B", + "up_scales", + "Y", + "", + config.k, + config.n, + config.bits, + config.block_size, + config.accuracy_level); + } else if (variant == MlpDecodeBenchmarkVariant::kSkipNormPassthroughThenFused) { + ORT_ENFORCE(norm_kind == MlpNormKind::kSkipSimplified, + "SkipNormPassthroughThenFused benchmark variant requires SkipSimplified norm kind."); + + auto* skip_sum_output = graph->add_output(); + skip_sum_output->set_name("SkipOut"); + skip_sum_output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + skip_sum_output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + skip_sum_output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); + + AddMatMulNBitsSiluMulNode(*graph, + "MatMulNBitsSiluMulDecodeWithSkipNormPassthrough", + "A", + "Skip", + "norm_scale", "gate_B", "gate_scales", "up_B", "up_scales", "Y", + "SkipOut", config.k, config.n, config.bits, config.block_size, config.accuracy_level); } else { + const char* mlp_input_name = norm_kind == MlpNormKind::kSkipSimplified ? "A_norm" : "A"; + if (norm_kind == MlpNormKind::kSkipSimplified) { + AddTensorValueInfo(*graph, "A_norm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); + auto* norm = graph->add_node(); + norm->set_name("InputSkipSimplifiedLayerNorm"); + norm->set_op_type("SkipSimplifiedLayerNormalization"); + norm->set_domain("com.microsoft"); + norm->add_input("A"); + norm->add_input("Skip"); + norm->add_input("norm_scale"); + norm->add_output("A_norm"); + auto* attr_epsilon = norm->add_attribute(); + attr_epsilon->set_name("epsilon"); + attr_epsilon->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); + attr_epsilon->set_f(1e-6f); + } + AddTensorValueInfo(*graph, "gate_mm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); AddTensorValueInfo(*graph, "up_mm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); AddTensorValueInfo(*graph, "gate_sigmoid", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); @@ -986,7 +1142,7 @@ std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& co AddMatMulNBitsNode(*graph, "GateMatMulNBitsDecode", - "A", + mlp_input_name, "gate_B", "gate_scales", "gate_mm", @@ -997,7 +1153,7 @@ std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& co config.accuracy_level); AddMatMulNBitsNode(*graph, "UpMatMulNBitsDecode", - "A", + mlp_input_name, "up_B", "up_scales", "up_mm", @@ -1233,7 +1389,8 @@ void ValidateMlpDecodeOutputs(const std::vector& unfused_model_data, const std::vector& fused_model_data, const std::unordered_map& provider_options, const char* const* input_names, - const Ort::Value* input_tensor, + const Ort::Value* input_tensors, + size_t input_count, const char* const* output_names) { Ort::Session unfused_session = CreateSessionFromModelData(unfused_model_data, &provider_options, @@ -1242,8 +1399,8 @@ void ValidateMlpDecodeOutputs(const std::vector& unfused_model_data, &provider_options, GraphOptimizationLevel::ORT_ENABLE_ALL); - auto unfused_outputs = unfused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensor, 1, output_names, 1); - auto fused_outputs = fused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensor, 1, output_names, 1); + auto unfused_outputs = unfused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, 1); + auto fused_outputs = fused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, 1); if (unfused_outputs.size() != 1 || fused_outputs.size() != 1) { throw std::runtime_error("Expected a single output from both unfused and fused MLP sessions."); @@ -1515,7 +1672,9 @@ void BenchmarkWebGpuMatMulNBitsQkvDecode(benchmark::State& state, QkvDecodeBench } } -void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, MlpDecodeBenchmarkVariant variant) { +void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, + MlpDecodeBenchmarkVariant variant, + MlpNormKind norm_kind) { try { const MlpDecodeBenchConfig config{ state.range(0), @@ -1530,53 +1689,67 @@ void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, MlpDecodeBench return; } - const MlpTrafficStats traffic = CalculateMlpTrafficStats(config, variant); - std::vector model_data = SerializeMatMulNBitsMlpModel(config, variant); + const MlpTrafficStats traffic = CalculateMlpTrafficStats(config, variant, norm_kind); + std::vector model_data = SerializeMatMulNBitsMlpModel(config, variant, norm_kind); const SelectedWebGpuContext& selected_context = GetSelectedWebGpuContext(); const GraphOptimizationLevel optimization_level = variant == MlpDecodeBenchmarkVariant::kUnfused ? GraphOptimizationLevel::ORT_DISABLE_ALL - : GraphOptimizationLevel::ORT_ENABLE_ALL; + : GraphOptimizationLevel::ORT_ENABLE_ALL; Ort::Session session = CreateSessionFromModelData(model_data, &selected_context.provider_options, optimization_level); Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); std::vector input_shape{1, config.k}; std::vector activation(static_cast(config.k)); + std::vector skip_activation(static_cast(config.k)); std::mt19937 rng(123); std::uniform_real_distribution dist(-1.0f, 1.0f); for (auto& value : activation) { value = Ort::Float16_t(dist(rng)); } + for (auto& value : skip_activation) { + value = Ort::Float16_t(dist(rng)); + } - const char* input_names[] = {"A"}; + const char* plain_input_names[] = {"A"}; + const char* skip_input_names[] = {"A", "Skip"}; + const char* const* input_names = norm_kind == MlpNormKind::kSkipSimplified ? skip_input_names : plain_input_names; const char* output_names[] = {"Y"}; + const size_t input_count = norm_kind == MlpNormKind::kSkipSimplified ? 2u : 1u; - auto input_tensor = Ort::Value::CreateTensor(memory_info, - activation.data(), - activation.size(), - input_shape.data(), - input_shape.size()); + std::array input_tensors = { + Ort::Value::CreateTensor(memory_info, + activation.data(), + activation.size(), + input_shape.data(), + input_shape.size()), + Ort::Value::CreateTensor(memory_info, + skip_activation.data(), + skip_activation.size(), + input_shape.data(), + input_shape.size())}; Ort::RunOptions run_options = CreateBenchmarkRunOptions(); if (!IsDecodeBenchmarkPerfMode()) { - ValidateMlpDecodeOutputs(SerializeMatMulNBitsMlpModel(config, MlpDecodeBenchmarkVariant::kUnfused), - SerializeMatMulNBitsMlpModel(config, MlpDecodeBenchmarkVariant::kFused), + ValidateMlpDecodeOutputs(SerializeMatMulNBitsMlpModel(config, MlpDecodeBenchmarkVariant::kUnfused, norm_kind), + SerializeMatMulNBitsMlpModel(config, variant, norm_kind), selected_context.provider_options, input_names, - &input_tensor, + input_tensors.data(), + input_count, output_names); } for (int i = 0; i < kDecodeWarmupRuns; ++i) { - auto warmup_outputs = session.Run(run_options, input_names, &input_tensor, 1, output_names, 1); + auto warmup_outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, 1); benchmark::DoNotOptimize(warmup_outputs); } double total_kernel_seconds = 0.0; for (auto _ : state) { const auto kernel_start = std::chrono::steady_clock::now(); - auto outputs = session.Run(run_options, input_names, &input_tensor, 1, output_names, 1); + auto outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, 1); const auto kernel_end = std::chrono::steady_clock::now(); total_kernel_seconds += std::chrono::duration(kernel_end - kernel_start).count(); benchmark::DoNotOptimize(outputs); @@ -1588,13 +1761,15 @@ void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, MlpDecodeBench ? traffic.total_bytes * static_cast(state.iterations()) / total_kernel_seconds : 0.0; - state.SetLabel(GetMlpDecodeBenchmarkLabel(variant)); + state.SetLabel(GetMlpDecodeBenchmarkLabel(variant, norm_kind)); state.counters["TFLOPS"] = benchmark::Counter( total_flops, benchmark::Counter::kIsIterationInvariantRate); state.counters["ApproxMemBW_GBps"] = benchmark::Counter(achieved_bandwidth_bytes_per_second / 1.0e9); state.counters["ApproxTraffic_MB"] = benchmark::Counter(traffic.total_bytes / 1.0e6); state.counters["Input_MB"] = benchmark::Counter(traffic.input_bytes / 1.0e6); + state.counters["SkipInput_MB"] = benchmark::Counter(traffic.skip_input_bytes / 1.0e6); + state.counters["NormScale_MB"] = benchmark::Counter(traffic.norm_scale_bytes / 1.0e6); state.counters["PackedW_MB"] = benchmark::Counter(traffic.packed_weight_bytes / 1.0e6); state.counters["Scales_MB"] = benchmark::Counter(traffic.scale_bytes / 1.0e6); state.counters["Intermediate_MB"] = benchmark::Counter(traffic.intermediate_bytes / 1.0e6); @@ -1606,11 +1781,27 @@ void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, MlpDecodeBench } static void BM_WebGpuMatMulNBitsMlpDecodeUnfused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused); + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused, MlpNormKind::kNone); } static void BM_WebGpuMatMulNBitsMlpDecodeFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused); + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused, MlpNormKind::kNone); +} + +static void BM_WebGpuMatMulNBitsMlpSkipDecodeUnfused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused, MlpNormKind::kSkipSimplified); +} + +static void BM_WebGpuMatMulNBitsMlpSkipDecodeFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused, MlpNormKind::kSkipSimplified); +} + +static void BM_WebGpuMatMulNBitsMlpSkipDecodeSkipNormThenFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kSkipNormThenFused, MlpNormKind::kSkipSimplified); +} + +static void BM_WebGpuMatMulNBitsMlpSkipDecodeSkipNormPassthroughThenFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kSkipNormPassthroughThenFused, MlpNormKind::kSkipSimplified); } static void BM_WebGpuMatMulNBitsQkvDecodeUnfused(benchmark::State& state) { @@ -1656,42 +1847,60 @@ void ApplyWebGpuMatMulNBitsQkvDecodeArgs(benchmark::internal::Benchmark* benchma BENCHMARK(BM_WebGpuMatMulNBitsMlpDecodeUnfused) ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->Repetitions(5) ->ReportAggregatesOnly() ->UseRealTime() ->Unit(benchmark::TimeUnit::kMicrosecond); BENCHMARK(BM_WebGpuMatMulNBitsMlpDecodeFused) ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->Repetitions(5) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeUnfused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeFused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeSkipNormThenFused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeSkipNormPassthroughThenFused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) ->ReportAggregatesOnly() ->UseRealTime() ->Unit(benchmark::TimeUnit::kMicrosecond); BENCHMARK(BM_WebGpuMatMulNBitsQkvDecodeUnfused) ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->Repetitions(5) ->ReportAggregatesOnly() ->UseRealTime() ->Unit(benchmark::TimeUnit::kMicrosecond); BENCHMARK(BM_WebGpuMatMulNBitsQkvDecodeFused) ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->Repetitions(5) ->ReportAggregatesOnly() ->UseRealTime() ->Unit(benchmark::TimeUnit::kMicrosecond); BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipDecodeUnfused) ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->Repetitions(5) ->ReportAggregatesOnly() ->UseRealTime() ->Unit(benchmark::TimeUnit::kMicrosecond); BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipDecodeFused) ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->Repetitions(5) ->ReportAggregatesOnly() ->UseRealTime() ->Unit(benchmark::TimeUnit::kMicrosecond); diff --git a/onnxruntime/test/optimizer/matmul_nbits_silu_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_silu_fusion_test.cc index 6d48344e27faf..2a1886caa618e 100644 --- a/onnxruntime/test/optimizer/matmul_nbits_silu_fusion_test.cc +++ b/onnxruntime/test/optimizer/matmul_nbits_silu_fusion_test.cc @@ -21,6 +21,17 @@ namespace test { namespace { +enum class NormAnchorKind { + kNone, + kSimplified, + kSkipSimplified, +}; + +enum class SkipOutputKind { + kNone, + kGraphOutput, +}; + void SetWebGpuProvider(Node& node) { node.SetExecutionProviderType(kWebGpuExecutionProvider); } @@ -35,10 +46,12 @@ NodeAttributes MakeMatMulNBitsAttrs(int64_t k, int64_t n, int64_t block_size, in return attrs; } -Status CheckMatMulNBitsSiluFusedGraph(const Graph& graph) { +Status CheckMatMulNBitsSiluFusedGraphImpl(const Graph& graph, NormAnchorKind norm_anchor_kind) { const auto op_to_count = CountOpsInGraph(graph); if (OpCount(op_to_count, "com.microsoft.MatMulNBitsSiluMul") != 1 || OpCount(op_to_count, "com.microsoft.MatMulNBits") != 0 || + OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "com.microsoft.SkipSimplifiedLayerNormalization") != 0 || OpCount(op_to_count, "Sigmoid") != 0 || OpCount(op_to_count, "Mul") != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected operator counts after MatMulNBitsSiluFusion."); @@ -49,13 +62,71 @@ Status CheckMatMulNBitsSiluFusedGraph(const Graph& graph) { ORT_RETURN_IF_NOT(node.Domain() == kMSDomain, "Fused node must be in com.microsoft domain."); ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kWebGpuExecutionProvider, "Fused node must be assigned to WebGPU EP."); + ORT_RETURN_IF_NOT(node.InputDefs().size() == 9u, "Fused node must have 9 inputs."); + const bool has_skip = node.InputDefs()[1] != nullptr && !node.InputDefs()[1]->Name().empty(); + const bool has_norm_scale = node.InputDefs()[2] != nullptr && !node.InputDefs()[2]->Name().empty(); + ORT_RETURN_IF_NOT(has_skip == (norm_anchor_kind == NormAnchorKind::kSkipSimplified), + "Unexpected skip input presence on fused node."); + ORT_RETURN_IF_NOT(has_norm_scale == (norm_anchor_kind != NormAnchorKind::kNone), + "Unexpected norm_scale input presence on fused node."); } } return Status::OK(); } -void BuildMatMulNBitsSiluWebGpuPattern(ModelTestBuilder& builder) { +Status CheckMatMulNBitsSiluFusedGraph(const Graph& graph) { + return CheckMatMulNBitsSiluFusedGraphImpl(graph, NormAnchorKind::kNone); +} + +Status CheckMatMulNBitsSiluSimplifiedFusedGraph(const Graph& graph) { + return CheckMatMulNBitsSiluFusedGraphImpl(graph, NormAnchorKind::kSimplified); +} + +Status CheckMatMulNBitsSiluSkipFusedGraph(const Graph& graph) { + return CheckMatMulNBitsSiluFusedGraphImpl(graph, NormAnchorKind::kSkipSimplified); +} + +Status CheckMatMulNBitsSiluSkipOutputPassthroughFusedGraph(const Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + if (OpCount(op_to_count, "com.microsoft.MatMulNBitsSiluMul") != 1 || + OpCount(op_to_count, "com.microsoft.MatMulNBits") != 0 || + OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "com.microsoft.SkipSimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "Sigmoid") != 0 || + OpCount(op_to_count, "Mul") != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unexpected operator counts after MatMulNBitsSiluFusion with skip output passthrough."); + } + + bool found_fused_node = false; + for (const auto& node : graph.Nodes()) { + if (node.OpType() != "MatMulNBitsSiluMul") { + continue; + } + + found_fused_node = true; + ORT_RETURN_IF_NOT(node.Domain() == kMSDomain, "Fused node must be in com.microsoft domain."); + ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kWebGpuExecutionProvider, + "Fused node must be assigned to WebGPU EP."); + ORT_RETURN_IF_NOT(node.InputDefs().size() == 9u, "Fused node must have 9 inputs."); + ORT_RETURN_IF_NOT(node.OutputDefs().size() == 2u, + "Fused node must expose Y and the passthrough residual output."); + const bool has_skip = node.InputDefs()[1] != nullptr && !node.InputDefs()[1]->Name().empty(); + const bool has_norm_scale = node.InputDefs()[2] != nullptr && !node.InputDefs()[2]->Name().empty(); + ORT_RETURN_IF_NOT(has_skip && has_norm_scale, + "Skip output passthrough should remain fused into MatMulNBitsSiluMul."); + ORT_RETURN_IF_NOT(node.OutputDefs()[1] != nullptr && !node.OutputDefs()[1]->Name().empty(), + "Expected fused node to preserve the residual passthrough output."); + } + + ORT_RETURN_IF_NOT(found_fused_node, "Expected a MatMulNBitsSiluMul node in the transformed graph."); + return Status::OK(); +} + +void BuildMatMulNBitsSiluWebGpuPatternImpl(ModelTestBuilder& builder, + NormAnchorKind norm_anchor_kind, + SkipOutputKind skip_output_kind = SkipOutputKind::kNone) { constexpr int64_t k = 16; constexpr int64_t n = 8; constexpr int64_t block_size = 16; @@ -79,6 +150,9 @@ void BuildMatMulNBitsSiluWebGpuPattern(ModelTestBuilder& builder) { NodeArg* up_scale = builder.MakeInitializer({n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); NodeArg* up_bias = builder.MakeInitializer({n}, MLFloat16(0.0f), MLFloat16(0.0f)); + NodeArg* normalized_input = norm_anchor_kind == NormAnchorKind::kNone + ? input + : builder.MakeIntermediate(std::vector{1, k}); NodeArg* gate_out = builder.MakeIntermediate(std::vector{1, n}); NodeArg* up_out = builder.MakeIntermediate(std::vector{1, n}); NodeArg* sigmoid_out = builder.MakeIntermediate(std::vector{1, n}); @@ -86,12 +160,43 @@ void BuildMatMulNBitsSiluWebGpuPattern(ModelTestBuilder& builder) { NodeArg* output = builder.MakeOutput(std::vector{1, n}); NodeAttributes matmul_attrs = MakeMatMulNBitsAttrs(k, n, block_size, bits, accuracy_level); - Node& gate_matmul = builder.AddNode("MatMulNBits", {input, gate_weight, gate_scale, optional_tensor, optional_tensor, gate_bias}, {gate_out}, kMSDomain, &matmul_attrs); - Node& up_matmul = builder.AddNode("MatMulNBits", {input, up_weight, up_scale, optional_tensor, optional_tensor, up_bias}, {up_out}, kMSDomain, &matmul_attrs); + Node* norm = nullptr; + if (norm_anchor_kind == NormAnchorKind::kSkipSimplified) { + NodeArg* skip_input = builder.MakeInput( + std::vector{1, k}, + std::vector(static_cast(k), MLFloat16(0.25f))); + NodeArg* norm_scale = builder.MakeInitializer({k}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* optional_norm_output_1 = builder.MakeOptionalTensor(); + NodeArg* optional_norm_output_2 = builder.MakeOptionalTensor(); + std::vector norm_outputs{normalized_input}; + if (skip_output_kind == SkipOutputKind::kGraphOutput) { + NodeArg* residual_output = builder.MakeOutput(std::vector{1, k}); + norm_outputs.push_back(optional_norm_output_1); + norm_outputs.push_back(optional_norm_output_2); + norm_outputs.push_back(residual_output); + } + norm = &builder.AddNode("SkipSimplifiedLayerNormalization", {input, skip_input, norm_scale}, norm_outputs, + kMSDomain); + } else if (norm_anchor_kind == NormAnchorKind::kSimplified) { + NodeArg* norm_scale = builder.MakeInitializer({k}, MLFloat16(1.0f), MLFloat16(1.0f)); + norm = &builder.AddNode("SimplifiedLayerNormalization", {input, norm_scale}, {normalized_input}); + } + + Node& gate_matmul = builder.AddNode("MatMulNBits", + {normalized_input, gate_weight, gate_scale, optional_tensor, optional_tensor, + gate_bias}, + {gate_out}, kMSDomain, &matmul_attrs); + Node& up_matmul = builder.AddNode("MatMulNBits", + {normalized_input, up_weight, up_scale, optional_tensor, optional_tensor, + up_bias}, + {up_out}, kMSDomain, &matmul_attrs); Node& sigmoid = builder.AddNode("Sigmoid", {gate_out}, {sigmoid_out}); Node& silu_mul = builder.AddNode("Mul", {gate_out, sigmoid_out}, {silu_out}); Node& final_mul = builder.AddNode("Mul", {silu_out, up_out}, {output}); + if (norm != nullptr) { + SetWebGpuProvider(*norm); + } SetWebGpuProvider(gate_matmul); SetWebGpuProvider(up_matmul); SetWebGpuProvider(sigmoid); @@ -99,6 +204,22 @@ void BuildMatMulNBitsSiluWebGpuPattern(ModelTestBuilder& builder) { SetWebGpuProvider(final_mul); } +void BuildMatMulNBitsSiluWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsSiluWebGpuPatternImpl(builder, NormAnchorKind::kNone); +} + +void BuildMatMulNBitsSiluSimplifiedWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsSiluWebGpuPatternImpl(builder, NormAnchorKind::kSimplified); +} + +void BuildMatMulNBitsSiluSkipWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsSiluWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified); +} + +void BuildMatMulNBitsSiluSkipOutputPassthroughWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsSiluWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified, SkipOutputKind::kGraphOutput); +} + } // namespace TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionFusesWebGpuPattern) { @@ -113,6 +234,42 @@ TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionFusesWebGpuPattern) { CheckMatMulNBitsSiluFusedGraph)); } +TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionFusesSkipWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsSiluSkipWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsSiluSkipFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionFusesSkipWebGpuPatternWithResidualOutputPassthrough) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsSiluSkipOutputPassthroughWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsSiluSkipOutputPassthroughFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionFusesSimplifiedWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsSiluSimplifiedWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsSiluSimplifiedFusedGraph)); +} + TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionMatchesUnfusedWebGpuResults) { auto webgpu_ep = DefaultWebGpuExecutionProvider(); if (!webgpu_ep) { @@ -137,6 +294,78 @@ TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionMatchesUnfusedWebGpuResult std::move(webgpu_ep)); } +TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionMatchesUnfusedSkipWebGpuResults) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsSiluSkipFusedGraph(session.GetGraph())); + }; + + TransformerTester( + BuildMatMulNBitsSiluSkipWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + {}, + {}, + std::move(webgpu_ep)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionMatchesUnfusedSkipWebGpuResultsWithResidualOutputPassthrough) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsSiluSkipOutputPassthroughFusedGraph(session.GetGraph())); + }; + + TransformerTester( + BuildMatMulNBitsSiluSkipOutputPassthroughWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + {}, + {}, + std::move(webgpu_ep)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionMatchesUnfusedSimplifiedWebGpuResults) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsSiluSimplifiedFusedGraph(session.GetGraph())); + }; + + TransformerTester( + BuildMatMulNBitsSiluSimplifiedWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + {}, + {}, + std::move(webgpu_ep)); +} + #endif // !defined(DISABLE_CONTRIB_OPS) } // namespace test From 01671d9c4e01d111cb47e5e0d0ef82f61f2a178e Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Wed, 29 Apr 2026 18:02:20 -0700 Subject: [PATCH 08/26] Cleanup --- ...template => dp4a_matmul_mlp.wgsl.template} | 0 .../webgpu/quantization/matmul_nbits.cc | 40 +- .../quantization/matmul_nbits_common.cc | 106 ++ .../webgpu/quantization/matmul_nbits_common.h | 37 + ...tmul_nbits_silu.cc => matmul_nbits_mlp.cc} | 378 ++---- ...matmul_nbits_silu.h => matmul_nbits_mlp.h} | 15 +- ...emplate => matmul_nbits_mlp.wgsl.template} | 40 - ...tmul_nbits_mlp_wide_tile_m1.wgsl.template} | 0 ...l_nbits_qkv_sln.cc => matmul_nbits_qkv.cc} | 214 +--- ...mul_nbits_qkv_sln.h => matmul_nbits_qkv.h} | 6 +- ...emplate => matmul_nbits_qkv.wgsl.template} | 44 - .../webgpu/webgpu_contrib_kernels.cc | 8 +- .../core/graph/contrib_ops/contrib_defs.cc | 41 +- .../core/optimizer/graph_transformer_utils.cc | 32 +- ...u_fusion.cc => matmul_nbits_mlp_fusion.cc} | 106 +- ...ilu_fusion.h => matmul_nbits_mlp_fusion.h} | 6 +- ...n_fusion.cc => matmul_nbits_qkv_fusion.cc} | 29 +- ...sln_fusion.h => matmul_nbits_qkv_fusion.h} | 6 +- .../core/providers/webgpu/allocator.cc | 14 +- onnxruntime/core/providers/webgpu/allocator.h | 5 +- .../core/providers/webgpu/webgpu_context.h | 2 +- .../webgpu/webgpu_execution_provider.cc | 25 +- .../webgpu/webgpu_execution_provider.h | 1 + .../webgpu_matmul_nbits_decode.cc | 1021 +++++++---------- .../optimizer/graph_transform_utils_test.cc | 76 -- ...est.cc => matmul_nbits_mlp_fusion_test.cc} | 176 ++- ...est.cc => matmul_nbits_qkv_fusion_test.cc} | 116 +- 27 files changed, 1054 insertions(+), 1490 deletions(-) rename onnxruntime/contrib_ops/webgpu/quantization/{dp4a_matmul_silu_mul.wgsl.template => dp4a_matmul_mlp.wgsl.template} (100%) rename onnxruntime/contrib_ops/webgpu/quantization/{matmul_nbits_silu.cc => matmul_nbits_mlp.cc} (63%) rename onnxruntime/contrib_ops/webgpu/quantization/{matmul_nbits_silu.h => matmul_nbits_mlp.h} (67%) rename onnxruntime/contrib_ops/webgpu/quantization/{matmul_nbits_silu_mul.wgsl.template => matmul_nbits_mlp.wgsl.template} (93%) rename onnxruntime/contrib_ops/webgpu/quantization/{matmul_nbits_silu_mul_wide_tile_m1.wgsl.template => matmul_nbits_mlp_wide_tile_m1.wgsl.template} (100%) rename onnxruntime/contrib_ops/webgpu/quantization/{matmul_nbits_qkv_sln.cc => matmul_nbits_qkv.cc} (76%) rename onnxruntime/contrib_ops/webgpu/quantization/{matmul_nbits_qkv_sln.h => matmul_nbits_qkv.h} (78%) rename onnxruntime/contrib_ops/webgpu/quantization/{matmul_nbits_qkv_sln.wgsl.template => matmul_nbits_qkv.wgsl.template} (87%) rename onnxruntime/core/optimizer/{matmul_nbits_silu_fusion.cc => matmul_nbits_mlp_fusion.cc} (79%) rename onnxruntime/core/optimizer/{matmul_nbits_silu_fusion.h => matmul_nbits_mlp_fusion.h} (57%) rename onnxruntime/core/optimizer/{matmul_nbits_qkv_sln_fusion.cc => matmul_nbits_qkv_fusion.cc} (90%) rename onnxruntime/core/optimizer/{matmul_nbits_qkv_sln_fusion.h => matmul_nbits_qkv_fusion.h} (65%) rename onnxruntime/test/optimizer/{matmul_nbits_silu_fusion_test.cc => matmul_nbits_mlp_fusion_test.cc} (62%) rename onnxruntime/test/optimizer/{matmul_nbits_qkv_sln_fusion_test.cc => matmul_nbits_qkv_fusion_test.cc} (63%) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_silu_mul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_mlp.wgsl.template similarity index 100% rename from onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_silu_mul.wgsl.template rename to onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_mlp.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 0db99c816dc29..7a0df249a6eee 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -18,10 +18,6 @@ namespace onnxruntime { namespace contrib { namespace webgpu { -namespace { -constexpr unsigned int kMinMForTileOptimization = 4; -} // namespace - ONNX_OPERATOR_KERNEL_EX( MatMulNBits, kMSDomain, @@ -222,29 +218,43 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, uint32_t zero_blocks_per_col = (n_blocks_per_col + zp_elements_per_byte - 1) / zp_elements_per_byte * zp_elements_per_byte; #if !defined(__wasm__) - int32_t subgroup_matrix_config_index = -1; // apple|intel - Experimental dawn support for subgroup matrix matmul. - if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) && (context.AdapterInfo().vendor == std::string_view{"apple"} || context.AdapterInfo().vendor == std::string_view{"intel"}) && - CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, static_cast(nbits), y->DataType() == DataTypeImpl::GetType(), subgroup_matrix_config_index)) { + int32_t subgroup_matrix_config_index = -1; + if (WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(a, + K_op, + N_op, + block_size_op, + accuracy_level, + nbits, + context, + y, + has_weight_idx_indirect, + &subgroup_matrix_config_index)) { return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index, weight_index_indirect); } #endif // On FP32 only GPUs and Qualcomm GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M. // DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values). - if (((M >= kMinMForTileOptimization && !has_weight_idx_indirect) || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && - CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) { + if (WouldApplyDP4AMatMulNBitsInCurrentDispatch(a, + K_op, + N_op, + block_size_op, + accuracy_level, + context, + y, + has_weight_idx_indirect)) { return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast(nbits), context, y, weight_index, weight_index_indirect); } // WideTileProgram // This program is optimized for Block32 prefill using Tile16x128. - const bool use_wide_tile_program = !has_weight_idx_indirect && - block_size == 32 && - components_a == 4 && - components_b == 4 && - nbits != 2 && - M >= kMinMForTileOptimization; + const bool use_wide_tile_program = WouldApplyWideTileMatMulNBitsInCurrentDispatch(a, + K_op, + N_op, + block_size_op, + nbits, + has_weight_idx_indirect); if (use_wide_tile_program) { // Enforce output components to 1. diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc index b9eafeb43c7b6..2d08b159ab938 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc @@ -2,9 +2,16 @@ // Licensed under the MIT License. #include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" + #include + #include "core/common/common.h" +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" +#include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/framework/tensor_shape.h" namespace onnxruntime { namespace contrib { @@ -61,6 +68,105 @@ bool HasDP4ADeviceSupport(int context_id) { ctx.AdapterInfo().vendor != std::string_view{"apple"}; } +bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + int64_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect, + int32_t* subgroup_matrix_config_index) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } + + const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t N = onnxruntime::narrow(helper.N()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_op); + +#if !defined(__wasm__) + int32_t local_subgroup_matrix_config_index = -1; + return (M >= kMinMForTileOptimization && !has_weight_idx_indirect) && + (context.AdapterInfo().vendor == std::string_view{"apple"} || + context.AdapterInfo().vendor == std::string_view{"intel"}) && + CanApplySubgroupMatrixMatMulNBits(context, + accuracy_level, + block_size, + batch_count, + N, + K, + static_cast(nbits), + y->DataType() == DataTypeImpl::GetType(), + subgroup_matrix_config_index != nullptr ? *subgroup_matrix_config_index : local_subgroup_matrix_config_index); +#endif + + return false; +} + +bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } + + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t N = onnxruntime::narrow(helper.N()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_op); + const uint32_t components_a = GetMaxComponents(K); + + return ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) || + y->DataType() == DataTypeImpl::GetType() || + context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && + CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a); +} + +bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t nbits, + bool has_weight_idx_indirect) { + if (has_weight_idx_indirect) { + return false; + } + + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } + + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_op); + const uint32_t components_a = GetMaxComponents(K); + const uint32_t block_size_per_col = block_size; + const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + + return block_size == 32 && + components_a == 4 && + components_b == 4 && + nbits != 2 && + M >= kMinMForTileOptimization; +} + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h index 3db7c722b11eb..fbb6c8cfaa3c9 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h @@ -6,10 +6,20 @@ #include #include +namespace onnxruntime { +class Tensor; + +namespace webgpu { +class ComputeContext; +} // namespace webgpu +} + namespace onnxruntime { namespace contrib { namespace webgpu { +inline constexpr uint32_t kMinMForTileOptimization = 4u; + /** * Generates WebGPU shader code for reading zero points in quantized matrix multiplication * @@ -26,6 +36,33 @@ std::string GenerateZeroPointReadingCode(uint32_t nbits, bool has_zero_points, /// \p context_id is the WebGpuContext slot (0 for the default context). bool HasDP4ADeviceSupport(int context_id = 0); +bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + int64_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect = false, + int32_t* subgroup_matrix_config_index = nullptr); + +bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect = false); + +bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t nbits, + bool has_weight_idx_indirect = false); + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc similarity index 63% rename from onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc rename to onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc index 9b4f456c27894..4dbb4d4e48f56 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "contrib_ops/webgpu/quantization/matmul_nbits_silu.h" +#include "contrib_ops/webgpu/quantization/matmul_nbits_mlp.h" #include "contrib_ops/webgpu/quantization/matmul_nbits.h" #include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" @@ -21,104 +21,10 @@ namespace webgpu { namespace { -constexpr unsigned int kMinMForTileOptimization = 4; - constexpr uint32_t kFusedDecodeFastPathBits = 4u; constexpr uint32_t kFusedDecodeFastPathBlockSize = 32u; constexpr float kSkipSimplifiedLayerNormEpsilon = 1e-05f; -bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, - int64_t K_op, - int64_t N_op, - int64_t block_size_op, - int64_t accuracy_level, - int64_t nbits, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y) { - TensorShape b_shape({N_op, K_op}); - MatMulComputeHelper helper; - if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { - return false; - } - - const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); - const uint32_t M = onnxruntime::narrow(helper.M()); - const uint32_t N = onnxruntime::narrow(helper.N()); - const uint32_t K = onnxruntime::narrow(helper.K()); - const uint32_t block_size = onnxruntime::narrow(block_size_op); - -#if !defined(__wasm__) - int32_t subgroup_matrix_config_index = -1; - return (M >= kMinMForTileOptimization) && - (context.AdapterInfo().vendor == std::string_view{"apple"} || - context.AdapterInfo().vendor == std::string_view{"intel"}) && - CanApplySubgroupMatrixMatMulNBits(context, - accuracy_level, - block_size, - batch_count, - N, - K, - static_cast(nbits), - y->DataType() == DataTypeImpl::GetType(), - subgroup_matrix_config_index); -#endif - - return false; -} - -bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, - int64_t K_op, - int64_t N_op, - int64_t block_size_op, - int64_t nbits) { - TensorShape b_shape({N_op, K_op}); - MatMulComputeHelper helper; - if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { - return false; - } - - const uint32_t M = onnxruntime::narrow(helper.M()); - const uint32_t K = onnxruntime::narrow(helper.K()); - const uint32_t block_size = onnxruntime::narrow(block_size_op); - - const uint32_t components_a = GetMaxComponents(K); - const uint32_t block_size_per_col = block_size; - const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits); - const uint32_t blob_size_in_words = blob_size / 4; - const uint32_t components_b = GetMaxComponents(blob_size_in_words); - - return block_size == 32 && - components_a == 4 && - components_b == 4 && - nbits != 2 && - M >= kMinMForTileOptimization; -} - -bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, - int64_t K_op, - int64_t N_op, - int64_t block_size_op, - int64_t accuracy_level, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y) { - TensorShape b_shape({N_op, K_op}); - MatMulComputeHelper helper; - if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { - return false; - } - - const uint32_t M = onnxruntime::narrow(helper.M()); - const uint32_t N = onnxruntime::narrow(helper.N()); - const uint32_t K = onnxruntime::narrow(helper.K()); - const uint32_t block_size = onnxruntime::narrow(block_size_op); - const uint32_t components_a = GetMaxComponents(K); - - return ((M >= kMinMForTileOptimization) || - y->DataType() == DataTypeImpl::GetType() || - context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && - CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a); -} - TensorShape GetOverrideShape(const TensorShape& shape, int components) { return TensorShape{shape.Size() / components}; } @@ -234,20 +140,18 @@ Status ApplyUnfusedSiluMul(const Tensor* a, onnxruntime::webgpu::ComputeContext& context, Tensor* y); -class MatMulNBitsSiluMulDecodeProgram final : public Program { +class MatMulNBitsMlpDecodeProgram final : public Program { public: - MatMulNBitsSiluMulDecodeProgram(uint32_t tile_size, - bool has_gate_bias, - bool has_up_bias, - bool has_norm_input, - bool has_skip_input, - bool has_skip_output, - bool single_scale_weights, - uint32_t tile_size_k_vec, - uint32_t k_unroll_tiles, - bool has_full_n_tiles, - bool has_full_k_tiles) - : Program{"MatMulNBitsSiluMulDecode"}, + MatMulNBitsMlpDecodeProgram(uint32_t tile_size, + bool has_gate_bias, + bool has_up_bias, + bool has_norm_input, + bool has_skip_input, + bool has_skip_output, + bool single_scale_weights, + uint32_t tile_size_k_vec, + uint32_t k_unroll_tiles) + : Program{"MatMulNBitsMlpDecode"}, tile_size_(tile_size), has_gate_bias_(has_gate_bias), has_up_bias_(has_up_bias), @@ -256,9 +160,7 @@ class MatMulNBitsSiluMulDecodeProgram final : public Program { - public: - MatMulNBitsSiluMulWideTileM1Program(bool has_gate_bias, - bool has_up_bias, - uint32_t outputs_per_thread) - : Program{"MatMulNBitsSiluMulWideTileM1Decode"}, - has_gate_bias_(has_gate_bias), - has_up_bias_(has_up_bias), - outputs_per_thread_(outputs_per_thread) {} - - Status GenerateShaderCode(ShaderHelper& shader) const override { - const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - const auto& gate_b = shader.AddInput("gate_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - const auto& gate_scales_b = shader.AddInput("gate_scales_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - const auto& up_b = shader.AddInput("up_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - const auto& up_scales_b = shader.AddInput("up_scales_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - if (has_gate_bias_) { - shader.AddInput("gate_bias", ShaderUsage::UseUniform); - } - if (has_up_bias_) { - shader.AddInput("up_bias", ShaderUsage::UseUniform); - } - const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - - return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_silu_mul_wide_tile_m1.wgsl.template", - WGSL_TEMPLATE_PARAMETER(has_gate_bias, has_gate_bias_), - WGSL_TEMPLATE_PARAMETER(has_up_bias, has_up_bias_), - WGSL_TEMPLATE_PARAMETER(outputs_per_thread, outputs_per_thread_), - WGSL_TEMPLATE_VARIABLE(a, a), - WGSL_TEMPLATE_VARIABLE(gate_b, gate_b), - WGSL_TEMPLATE_VARIABLE(gate_scales_b, gate_scales_b), - WGSL_TEMPLATE_VARIABLE(output, output), - WGSL_TEMPLATE_VARIABLE(up_b, up_b), - WGSL_TEMPLATE_VARIABLE(up_scales_b, up_scales_b)); - } - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"batch_count", ProgramUniformVariableDataType::Uint32}, - {"N", ProgramUniformVariableDataType::Uint32}, - {"K_of_a", ProgramUniformVariableDataType::Uint32}, - {"K_of_b", ProgramUniformVariableDataType::Uint32}, - {"n_blocks_per_col", ProgramUniformVariableDataType::Uint32}); - - private: - bool has_gate_bias_; - bool has_up_bias_; - uint32_t outputs_per_thread_; }; -class MatMulNBitsSiluMulProgram final : public Program { +class MatMulNBitsMlpProgram final : public Program { public: - MatMulNBitsSiluMulProgram() : Program{"MatMulNBitsSiluMul"} {} + MatMulNBitsMlpProgram() : Program{"MatMulNBitsMlp"} {} Status GenerateShaderCode(ShaderHelper& shader) const override { const auto& gate = shader.AddInput("gate", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); @@ -471,20 +354,20 @@ class MatMulNBitsSiluMulProgram final : public ProgramShape(), b_shape, false, true)); @@ -498,7 +381,7 @@ Status ApplyUnfusedSiluMul(const Tensor* a, const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); const uint32_t vec_size = (data_size + 3u) / 4u; - MatMulNBitsSiluMulProgram program; + MatMulNBitsMlpProgram program; program .AddInputs({{&gate_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}, {&up_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}}) @@ -512,16 +395,16 @@ Status ApplyUnfusedSiluMul(const Tensor* a, } // namespace ONNX_OPERATOR_KERNEL_EX( - MatMulNBitsSiluMul, + MatMulNBitsMlp, kMSDomain, 1, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", WebGpuSupportedFloatTypes()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - MatMulNBitsSiluMul); + MatMulNBitsMlp); -Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* skip = context.Input(1); const Tensor* norm_scale = context.Input(2); @@ -533,7 +416,7 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& const Tensor* up_bias = context.Input(8); ORT_ENFORCE(skip == nullptr || norm_scale != nullptr, - "MatMulNBitsSiluMul requires norm_scale when skip is present."); + "MatMulNBitsMlp requires norm_scale when skip is present."); MatMulComputeHelper helper; TensorShape b_shape({N_, K_}); @@ -556,7 +439,9 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; Tensor* y = context.Output(0, output_shape); - Tensor* input_skip_bias_sum = skip != nullptr ? context.Output(1, a->Shape()) : nullptr; + Tensor* input_skip_bias_sum = (skip != nullptr && context.OutputCount() > 1) + ? context.Output(1, a->Shape()) + : nullptr; const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); if (data_size == 0) { return Status::OK(); @@ -566,6 +451,9 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& ORT_ENFORCE(norm_scale->Shape().Size() == K_, "norm_scale must have shape [K]."); } + const bool has_skip_input = skip != nullptr; + const bool has_skip_output = input_skip_bias_sum != nullptr; + const bool is_decode_fast_path_candidate = M == 1 && bits_ == kFusedDecodeFastPathBits && @@ -596,9 +484,6 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& block_size_, bits_); - // The experimental wide M==1 fused path regressed badly on NVIDIA decode shapes. - // Keep the implementation around for future work, but do not dispatch to it. - const bool can_use_decode_fast_path = is_decode_fast_path_candidate && !would_use_subgroup_unfused && @@ -606,23 +491,18 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& !would_use_wide_tile_unfused; if (can_use_decode_fast_path) { - //ORT_ENFORCE(false, "The experimental wide M==1 fused path regressed badly on NVIDIA decode shapes. Keep the implementation around for future work, but do not dispatch to it."); ORT_ENFORCE(bits_ == kFusedDecodeFastPathBits, - "MatMulNBitsSiluMulDecodeProgram is specialized for 4-bit weights only."); + "MatMulNBitsMlpDecodeProgram is specialized for 4-bit weights only."); ORT_ENFORCE(block_size == kFusedDecodeFastPathBlockSize, - "MatMulNBitsSiluMulDecodeProgram is specialized for block_size=32 only."); + "MatMulNBitsMlpDecodeProgram is specialized for block_size=32 only."); const bool has_gate_bias = gate_bias != nullptr; const bool has_up_bias = up_bias != nullptr; - const bool has_skip_input = skip != nullptr; - const bool has_skip_output = input_skip_bias_sum != nullptr; uint32_t workgroup_size = 128; uint32_t tile_size = 8; uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; - // For the smallest decode-like M==1 case, reduce K-split width and workgroup size - // so the generic fused kernel spends less time on reduction and barriers. if (context.AdapterInfo().vendor != std::string_view{"intel"} && N <= 2048) { workgroup_size = 64; tile_size = 4; @@ -631,12 +511,10 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& const uint32_t elements_in_value_b = components_b * (32u / onnxruntime::narrow(bits_)); const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; - const bool has_full_n_tiles = (N % tile_size) == 0; - const bool has_full_k_tiles = (K % tile_size_k) == 0; const uint32_t k_tile_iterations = K / tile_size_k; uint32_t k_unroll_tiles = 1; - if (has_full_k_tiles) { + if ((K % tile_size_k) == 0) { if (k_tile_iterations >= 8 && N <= 2048 && context.AdapterInfo().vendor != std::string_view{"intel"}) { k_unroll_tiles = 4; } else if (k_tile_iterations >= 4) { @@ -646,17 +524,15 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& const uint32_t num_N_tile = CeilDiv(N, tile_size); - MatMulNBitsSiluMulDecodeProgram program{tile_size, - has_gate_bias, - has_up_bias, - has_norm_input, - has_skip_input, - has_skip_output, - single_scale_weights, - tile_size_k_vec, - k_unroll_tiles, - has_full_n_tiles, - has_full_k_tiles}; + MatMulNBitsMlpDecodeProgram program{tile_size, + has_gate_bias, + has_up_bias, + has_norm_input, + has_skip_input, + has_skip_output, + single_scale_weights, + tile_size_k_vec, + k_unroll_tiles}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize(num_N_tile, 1, batch_count); program.AddInput({a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); @@ -690,8 +566,6 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& has_skip_output, tile_size_k_vec, k_unroll_tiles, - has_full_n_tiles, - has_full_k_tiles, "decode_4bit"); if (has_skip_output) { program.AddOutput({input_skip_bias_sum, @@ -712,7 +586,7 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); ORT_RETURN_IF_ERROR(ApplySkipSimplifiedLayerNorm(a, skip, norm_scale, kSkipSimplifiedLayerNormEpsilon, context, &normalized_a, input_skip_bias_sum)); - return ApplyUnfusedSiluMul(&normalized_a, + return ApplyUnfusedMlp(&normalized_a, gate_b, gate_scales, gate_bias, @@ -731,7 +605,7 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (norm_scale != nullptr) { Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); ORT_RETURN_IF_ERROR(ApplySimplifiedLayerNorm(a, norm_scale, kSkipSimplifiedLayerNormEpsilon, context, &normalized_a)); - return ApplyUnfusedSiluMul(&normalized_a, + return ApplyUnfusedMlp(&normalized_a, gate_b, gate_scales, gate_bias, @@ -747,7 +621,7 @@ Status MatMulNBitsSiluMul::ComputeInternal(onnxruntime::webgpu::ComputeContext& y); } - return ApplyUnfusedSiluMul(a, + return ApplyUnfusedMlp(a, gate_b, gate_scales, gate_bias, diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h similarity index 67% rename from onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.h rename to onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h index 476a76c72fa34..52333d293dce1 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h @@ -3,6 +3,8 @@ #pragma once +#include + #include "core/providers/webgpu/webgpu_kernel.h" namespace onnxruntime { @@ -12,16 +14,20 @@ namespace webgpu { using namespace onnxruntime::webgpu; using onnxruntime::webgpu::ComputeContext; -class MatMulNBitsSiluMul final : public WebGpuKernel { +class MatMulNBitsMlp final : public WebGpuKernel { public: - explicit MatMulNBitsSiluMul(const OpKernelInfo& info) : WebGpuKernel(info) { + explicit MatMulNBitsMlp(const OpKernelInfo& info) : WebGpuKernel(info) { K_ = info.GetAttr("K"); N_ = info.GetAttr("N"); block_size_ = info.GetAttr("block_size"); bits_ = info.GetAttr("bits"); accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4); + ORT_ENFORCE(info.GetAttr("activation", &activation_).IsOK(), + "MatMulNBitsMlp requires the 'activation' attribute."); ORT_ENFORCE(bits_ == 4 || bits_ == 8 || bits_ == 2, - "Only 4b/8b/2b quantization is supported for MatMulNBitsSiluMul op."); + "Only 4b/8b/2b quantization is supported for MatMulNBitsMlp op."); + ORT_ENFORCE(activation_ == "silu", + "MatMulNBitsMlp currently only supports activation='silu'."); } Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; @@ -32,8 +38,9 @@ class MatMulNBitsSiluMul final : public WebGpuKernel { int64_t block_size_; int64_t accuracy_level_; int64_t bits_; + std::string activation_; }; } // namespace webgpu } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template similarity index 93% rename from onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template rename to onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template index 190d06ba958e0..a183962b230da 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template @@ -5,8 +5,6 @@ #param component_a #param component_b #param elements_in_value_b -#param has_full_k_tiles -#param has_full_n_tiles #param single_scale_weights #param sub_tile_count #param has_norm_input @@ -43,19 +41,6 @@ fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) { let k_offset = kidx / component_a + col; let input_offset = batch * uniforms.K_of_a + k_offset; -#if has_full_k_tiles - let merged_value = load_merged_input(input_offset); -#if has_skip_output - if (b_global_base == 0u) { - input_skip_bias_sum.setByOffset(input_offset, input_skip_bias_sum_value_t(merged_value)); - } -#endif -#if has_norm_input - tile_A[col] = merged_value * input_a_value_t(input_a_element_t(inv_std)) * norm_scale.getByOffset(k_offset); -#else - tile_A[col] = merged_value; -#endif -#else if (k_offset < uniforms.K_of_a) { let merged_value = load_merged_input(input_offset); #if has_skip_output @@ -71,7 +56,6 @@ fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) } else { tile_A[col] = input_a_value_t(0); } -#endif } fn compute_gate_up_sums(b_global: u32, kidx: u32, idx: u32, k_offset: u32) -> vec2 { @@ -157,31 +141,11 @@ fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy { let b_global = b_global_base + local_row_offset + idy; let k_offset = kidx / elements_in_value_b + idx; -#if has_full_n_tiles -#if !has_full_k_tiles - if (k_offset < uniforms.K_of_b) { -#endif -#else -#if has_full_k_tiles - if (b_global < uniforms.N) { -#else if (b_global < uniforms.N && k_offset < uniforms.K_of_b) { -#endif -#endif let sums = compute_gate_up_sums(b_global, kidx, idx, k_offset); gate_inter_results[local_row_offset + idy][idx] += sums[0]; up_inter_results[local_row_offset + idy][idx] += sums[1]; -#if has_full_n_tiles -#if !has_full_k_tiles - } -#endif -#else -#if has_full_k_tiles } -#else - } -#endif -#endif } workgroupBarrier(); } @@ -280,11 +244,7 @@ $MAIN { } let b_global = b_global_base + local_idx; let output_idx = batch * uniforms.N + b_global; -#if has_full_n_tiles - { -#else if (b_global < uniforms.N) { -#endif #if has_gate_bias gate_output_value += gate_bias[b_global]; #endif diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul_wide_tile_m1.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp_wide_tile_m1.wgsl.template similarity index 100% rename from onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_silu_mul_wide_tile_m1.wgsl.template rename to onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp_wide_tile_m1.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc similarity index 76% rename from onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.cc rename to onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc index 0902dcb1617fb..3ca9b1280d011 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc @@ -1,7 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.h" +#include "contrib_ops/webgpu/quantization/matmul_nbits_qkv.h" + +#include #include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" #include "contrib_ops/webgpu/quantization/matmul_nbits.h" @@ -21,99 +23,6 @@ namespace webgpu { namespace { -constexpr unsigned int kMinMForTileOptimization = 4; - -bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, - int64_t K_op, - int64_t N_op, - int64_t block_size_op, - int64_t accuracy_level, - int64_t nbits, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y) { - TensorShape b_shape({N_op, K_op}); - MatMulComputeHelper helper; - if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { - return false; - } - - const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); - const uint32_t M = onnxruntime::narrow(helper.M()); - const uint32_t N = onnxruntime::narrow(helper.N()); - const uint32_t K = onnxruntime::narrow(helper.K()); - const uint32_t block_size = onnxruntime::narrow(block_size_op); - -#if !defined(__wasm__) - int32_t subgroup_matrix_config_index = -1; - return (M >= kMinMForTileOptimization) && - (context.AdapterInfo().vendor == std::string_view{"apple"} || - context.AdapterInfo().vendor == std::string_view{"intel"}) && - CanApplySubgroupMatrixMatMulNBits(context, - accuracy_level, - block_size, - batch_count, - N, - K, - static_cast(nbits), - y->DataType() == DataTypeImpl::GetType(), - subgroup_matrix_config_index); -#endif - - return false; -} - -bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, - int64_t K_op, - int64_t N_op, - int64_t block_size_op, - int64_t accuracy_level, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y) { - TensorShape b_shape({N_op, K_op}); - MatMulComputeHelper helper; - if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { - return false; - } - - const uint32_t M = onnxruntime::narrow(helper.M()); - const uint32_t N = onnxruntime::narrow(helper.N()); - const uint32_t K = onnxruntime::narrow(helper.K()); - const uint32_t block_size = onnxruntime::narrow(block_size_op); - const uint32_t components_a = GetMaxComponents(K); - - return ((M >= kMinMForTileOptimization) || - y->DataType() == DataTypeImpl::GetType() || - context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && - CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a); -} - -bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, - int64_t K_op, - int64_t N_op, - int64_t block_size_op, - int64_t nbits) { - TensorShape b_shape({N_op, K_op}); - MatMulComputeHelper helper; - if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { - return false; - } - - const uint32_t M = onnxruntime::narrow(helper.M()); - const uint32_t K = onnxruntime::narrow(helper.K()); - const uint32_t block_size = onnxruntime::narrow(block_size_op); - const uint32_t components_a = GetMaxComponents(K); - const uint32_t block_size_per_col = block_size; - const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits); - const uint32_t blob_size_in_words = blob_size / 4; - const uint32_t components_b = GetMaxComponents(blob_size_in_words); - - return block_size == 32 && - components_a == 4 && - components_b == 4 && - nbits != 2 && - M >= kMinMForTileOptimization; -} - TensorShape GetOverrideShape(const TensorShape& shape, int components) { return TensorShape{shape.Size() / components}; } @@ -275,26 +184,20 @@ Status ApplyUnfusedQKVSkipSimplifiedLayerNorm(const Tensor* a, return Status::OK(); } -class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final - : public Program { +class MatMulNBitsQkvDecodeProgram final + : public Program { public: - MatMulNBitsQKVSimplifiedLayerNormDecodeProgram(uint32_t tile_size, - bool single_scale_weights, - uint32_t tile_size_k_vec, - uint32_t k_unroll_tiles, - bool has_full_q_tiles, - bool has_full_kv_tiles, - bool has_full_k_tiles, - bool has_skip_input, - bool has_skip_output) - : Program{"MatMulNBitsQKVSimplifiedLayerNormDecode"}, + MatMulNBitsQkvDecodeProgram(uint32_t tile_size, + bool single_scale_weights, + uint32_t tile_size_k_vec, + uint32_t k_unroll_tiles, + bool has_skip_input, + bool has_skip_output) + : Program{"MatMulNBitsQkvDecode"}, tile_size_(tile_size), single_scale_weights_(single_scale_weights), tile_size_k_vec_(tile_size_k_vec), k_unroll_tiles_(k_unroll_tiles), - has_full_q_tiles_(has_full_q_tiles), - has_full_kv_tiles_(has_full_kv_tiles), - has_full_k_tiles_(has_full_k_tiles), has_skip_input_(has_skip_input), has_skip_output_(has_skip_output) {} @@ -308,10 +211,18 @@ class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final const auto& k_scales_b = shader.AddInput("k_scales_b"); const auto& v_b = shader.AddInput("v_b"); const auto& v_scales_b = shader.AddInput("v_scales_b"); - const auto& q_output = shader.AddOutput("q_output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - const auto& k_output = shader.AddOutput("k_output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - const auto& v_output = shader.AddOutput("v_output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& q_output = shader.AddOutput("q_output", + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + const auto& k_output = shader.AddOutput("k_output", + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + const auto& v_output = shader.AddOutput("v_output", + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); const auto* input_skip_bias_sum = has_skip_output_ ? &shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr; + const auto& skip_var = skip != nullptr ? *skip : a; + const auto& input_skip_bias_sum_var = input_skip_bias_sum != nullptr ? *input_skip_bias_sum : q_output; const uint32_t components_a = a.NumComponents(); const uint32_t components_b = q_b.NumComponents() / 4; @@ -323,14 +234,11 @@ class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final if (skip != nullptr) { if (input_skip_bias_sum != nullptr) { - return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv_sln.wgsl.template", + return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv.wgsl.template", WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), WGSL_TEMPLATE_PARAMETER(component_a, components_a), WGSL_TEMPLATE_PARAMETER(component_b, components_b), WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), - WGSL_TEMPLATE_PARAMETER(has_full_k_tiles, has_full_k_tiles_), - WGSL_TEMPLATE_PARAMETER(has_full_kv_tiles, has_full_kv_tiles_), - WGSL_TEMPLATE_PARAMETER(has_full_q_tiles, has_full_q_tiles_), WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_), WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_), WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_), @@ -340,7 +248,7 @@ class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), WGSL_TEMPLATE_VARIABLE(a, a), - WGSL_TEMPLATE_VARIABLE(input_skip_bias_sum, *input_skip_bias_sum), + WGSL_TEMPLATE_VARIABLE(input_skip_bias_sum, input_skip_bias_sum_var), WGSL_TEMPLATE_VARIABLE(k_b, k_b), WGSL_TEMPLATE_VARIABLE(k_output, k_output), WGSL_TEMPLATE_VARIABLE(k_scales_b, k_scales_b), @@ -348,20 +256,17 @@ class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final WGSL_TEMPLATE_VARIABLE(q_b, q_b), WGSL_TEMPLATE_VARIABLE(q_output, q_output), WGSL_TEMPLATE_VARIABLE(q_scales_b, q_scales_b), - WGSL_TEMPLATE_VARIABLE(skip, *skip), + WGSL_TEMPLATE_VARIABLE(skip, skip_var), WGSL_TEMPLATE_VARIABLE(v_b, v_b), WGSL_TEMPLATE_VARIABLE(v_output, v_output), WGSL_TEMPLATE_VARIABLE(v_scales_b, v_scales_b)); } - return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv_sln.wgsl.template", + return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv.wgsl.template", WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), WGSL_TEMPLATE_PARAMETER(component_a, components_a), WGSL_TEMPLATE_PARAMETER(component_b, components_b), WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), - WGSL_TEMPLATE_PARAMETER(has_full_k_tiles, has_full_k_tiles_), - WGSL_TEMPLATE_PARAMETER(has_full_kv_tiles, has_full_kv_tiles_), - WGSL_TEMPLATE_PARAMETER(has_full_q_tiles, has_full_q_tiles_), WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_), WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_), WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_), @@ -371,6 +276,7 @@ class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), WGSL_TEMPLATE_VARIABLE(a, a), + WGSL_TEMPLATE_VARIABLE(input_skip_bias_sum, input_skip_bias_sum_var), WGSL_TEMPLATE_VARIABLE(k_b, k_b), WGSL_TEMPLATE_VARIABLE(k_output, k_output), WGSL_TEMPLATE_VARIABLE(k_scales_b, k_scales_b), @@ -378,20 +284,17 @@ class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final WGSL_TEMPLATE_VARIABLE(q_b, q_b), WGSL_TEMPLATE_VARIABLE(q_output, q_output), WGSL_TEMPLATE_VARIABLE(q_scales_b, q_scales_b), - WGSL_TEMPLATE_VARIABLE(skip, *skip), + WGSL_TEMPLATE_VARIABLE(skip, skip_var), WGSL_TEMPLATE_VARIABLE(v_b, v_b), WGSL_TEMPLATE_VARIABLE(v_output, v_output), WGSL_TEMPLATE_VARIABLE(v_scales_b, v_scales_b)); } - return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv_sln.wgsl.template", + return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv.wgsl.template", WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), WGSL_TEMPLATE_PARAMETER(component_a, components_a), WGSL_TEMPLATE_PARAMETER(component_b, components_b), WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), - WGSL_TEMPLATE_PARAMETER(has_full_k_tiles, has_full_k_tiles_), - WGSL_TEMPLATE_PARAMETER(has_full_kv_tiles, has_full_kv_tiles_), - WGSL_TEMPLATE_PARAMETER(has_full_q_tiles, has_full_q_tiles_), WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_), WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_), WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_), @@ -401,6 +304,7 @@ class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), WGSL_TEMPLATE_VARIABLE(a, a), + WGSL_TEMPLATE_VARIABLE(input_skip_bias_sum, input_skip_bias_sum_var), WGSL_TEMPLATE_VARIABLE(k_b, k_b), WGSL_TEMPLATE_VARIABLE(k_output, k_output), WGSL_TEMPLATE_VARIABLE(k_scales_b, k_scales_b), @@ -408,6 +312,7 @@ class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final WGSL_TEMPLATE_VARIABLE(q_b, q_b), WGSL_TEMPLATE_VARIABLE(q_output, q_output), WGSL_TEMPLATE_VARIABLE(q_scales_b, q_scales_b), + WGSL_TEMPLATE_VARIABLE(skip, skip_var), WGSL_TEMPLATE_VARIABLE(v_b, v_b), WGSL_TEMPLATE_VARIABLE(v_output, v_output), WGSL_TEMPLATE_VARIABLE(v_scales_b, v_scales_b)); @@ -431,9 +336,6 @@ class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final bool single_scale_weights_; uint32_t tile_size_k_vec_; uint32_t k_unroll_tiles_; - bool has_full_q_tiles_; - bool has_full_kv_tiles_; - bool has_full_k_tiles_; bool has_skip_input_; bool has_skip_output_; }; @@ -441,16 +343,16 @@ class MatMulNBitsQKVSimplifiedLayerNormDecodeProgram final } // namespace ONNX_OPERATOR_KERNEL_EX( - MatMulNBitsQKVSimplifiedLayerNorm, + MatMulNBitsQkv, kMSDomain, 1, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", WebGpuSupportedFloatTypes()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - MatMulNBitsQKVSimplifiedLayerNorm); + MatMulNBitsQkv); -Status MatMulNBitsQKVSimplifiedLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { +Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* skip = context.Input(1); const Tensor* norm_scale = context.Input(2); @@ -461,8 +363,8 @@ Status MatMulNBitsQKVSimplifiedLayerNorm::ComputeInternal(onnxruntime::webgpu::C const Tensor* v_b = context.Input(7); const Tensor* v_scales = context.Input(8); - ORT_ENFORCE(bits_ == 4, "MatMulNBitsQKVSimplifiedLayerNorm currently supports 4-bit weights only."); - ORT_ENFORCE(block_size_ == 32, "MatMulNBitsQKVSimplifiedLayerNorm currently supports block_size=32 only."); + ORT_ENFORCE(bits_ == 4, "MatMulNBitsQkv currently supports 4-bit weights only."); + ORT_ENFORCE(block_size_ == 32, "MatMulNBitsQkv currently supports block_size=32 only."); TensorShape q_b_shape({Nq_, K_}); MatMulComputeHelper helper; @@ -481,7 +383,9 @@ Status MatMulNBitsQKVSimplifiedLayerNorm::ComputeInternal(onnxruntime::webgpu::C Tensor* q_output = context.Output(0, q_shape); Tensor* k_output = context.Output(1, kv_shape); Tensor* v_output = context.Output(2, kv_shape); - Tensor* input_skip_bias_sum = skip != nullptr ? context.Output(3, a->Shape()) : nullptr; + Tensor* input_skip_bias_sum = (skip != nullptr && context.OutputCount() > 3) + ? context.Output(3, a->Shape()) + : nullptr; if (q_output->Shape().Size() == 0) { return Status::OK(); } @@ -576,6 +480,7 @@ Status MatMulNBitsQKVSimplifiedLayerNorm::ComputeInternal(onnxruntime::webgpu::C uint32_t workgroup_size = 128; uint32_t tile_size = 8; uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; + if (context.AdapterInfo().vendor != std::string_view{"intel"} && std::max(Nq, Nkv) <= 2048) { workgroup_size = 64; tile_size = 4; @@ -584,13 +489,17 @@ Status MatMulNBitsQKVSimplifiedLayerNorm::ComputeInternal(onnxruntime::webgpu::C const uint32_t elements_in_value_b = components_b * (32u / onnxruntime::narrow(bits_)); const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; - const bool has_full_q_tiles = (Nq % tile_size) == 0; - const bool has_full_kv_tiles = (Nkv % tile_size) == 0; - const bool has_full_k_tiles = (K % tile_size_k) == 0; const uint32_t k_tile_iterations = K / tile_size_k; + std::optional input_skip_bias_sum_scratch; + Tensor* decode_input_skip_bias_sum = input_skip_bias_sum; + if (skip != nullptr && decode_input_skip_bias_sum == nullptr) { + input_skip_bias_sum_scratch.emplace(context.CreateGPUTensor(a->DataType(), a->Shape())); + decode_input_skip_bias_sum = &*input_skip_bias_sum_scratch; + } + uint32_t k_unroll_tiles = 1; - if (has_full_k_tiles) { + if ((K % tile_size_k) == 0) { if (k_tile_iterations >= 8 && std::max(Nq, Nkv) <= 2048 && context.AdapterInfo().vendor != std::string_view{"intel"}) { k_unroll_tiles = 4; @@ -600,15 +509,12 @@ Status MatMulNBitsQKVSimplifiedLayerNorm::ComputeInternal(onnxruntime::webgpu::C } const uint32_t num_N_tile = CeilDiv(std::max(Nq, Nkv), tile_size); - MatMulNBitsQKVSimplifiedLayerNormDecodeProgram program{tile_size, - single_scale_weights, - tile_size_k_vec, - k_unroll_tiles, - has_full_q_tiles, - has_full_kv_tiles, - has_full_k_tiles, - skip != nullptr, - input_skip_bias_sum != nullptr}; + MatMulNBitsQkvDecodeProgram program{tile_size, + single_scale_weights, + tile_size_k_vec, + k_unroll_tiles, + skip != nullptr, + decode_input_skip_bias_sum != nullptr}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize(num_N_tile, 1, batch_count); program @@ -644,16 +550,12 @@ Status MatMulNBitsQKVSimplifiedLayerNorm::ComputeInternal(onnxruntime::webgpu::C tile_size, tile_size_k_vec, k_unroll_tiles, - has_full_q_tiles, - has_full_kv_tiles, - has_full_k_tiles, single_scale_weights, skip != nullptr, - input_skip_bias_sum != nullptr, + decode_input_skip_bias_sum != nullptr, "decode_qkv_sln"); - - if (input_skip_bias_sum != nullptr) { - program.AddOutput({input_skip_bias_sum, + if (decode_input_skip_bias_sum != nullptr) { + program.AddOutput({decode_input_skip_bias_sum, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.h similarity index 78% rename from onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.h rename to onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.h index 810ffbcdf4885..4d57ab5ac2b3c 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.h @@ -11,9 +11,9 @@ namespace webgpu { using namespace onnxruntime::webgpu; -class MatMulNBitsQKVSimplifiedLayerNorm final : public WebGpuKernel { +class MatMulNBitsQkv final : public WebGpuKernel { public: - explicit MatMulNBitsQKVSimplifiedLayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + explicit MatMulNBitsQkv(const OpKernelInfo& info) : WebGpuKernel(info) { K_ = info.GetAttr("K"); Nq_ = info.GetAttr("Nq"); Nkv_ = info.GetAttr("Nkv"); @@ -22,7 +22,7 @@ class MatMulNBitsQKVSimplifiedLayerNorm final : public WebGpuKernel { accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4); epsilon_ = info.GetAttrOrDefault("epsilon", 1e-6f); ORT_ENFORCE(bits_ == 4, - "MatMulNBitsQKVSimplifiedLayerNorm currently supports 4-bit weights only."); + "MatMulNBitsQkv currently supports 4-bit weights only."); } Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template similarity index 87% rename from onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.wgsl.template rename to onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template index 18d3aa4270c67..fe0dcd7fa7a1b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv_sln.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template @@ -5,9 +5,6 @@ #param component_a #param component_b #param elements_in_value_b -#param has_full_k_tiles -#param has_full_kv_tiles -#param has_full_q_tiles #param has_skip_input #param has_skip_output #param k_unroll_tiles @@ -73,15 +70,6 @@ fn load_a_vec4(a_offset: u32) -> vec4 { fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) { let k_offset = kidx / component_a + col; let input_offset = batch * uniforms.K_of_a + k_offset; -#if has_full_k_tiles - let merged_value = load_merged_input(input_offset); -#if has_skip_output - if (b_global_base == 0u) { - input_skip_bias_sum.setByOffset(input_offset, input_skip_bias_sum_value_t(merged_value)); - } -#endif - tile_A[col] = merged_value * input_a_value_t(input_a_element_t(inv_std)) * norm_scale.getByOffset(k_offset); -#else if (k_offset < uniforms.K_of_a) { let merged_value = load_merged_input(input_offset); #if has_skip_output @@ -93,7 +81,6 @@ fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) { } else { tile_A[col] = input_a_value_t(0); } -#endif } fn compute_projection_sum(weight: q_b_value_t, @@ -158,18 +145,7 @@ fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy #if !single_scale_weights let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; #endif - #if has_full_k_tiles - { - #else if (k_offset < uniforms.K_of_b) { - #endif - #if has_full_q_tiles - #if !single_scale_weights - let q_scale_b = q_output_element_t(q_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); - #endif - let q_weight = q_b.getByOffset(b_global * uniforms.K_of_b + k_offset); - q_inter_results[local_row_offset + idy][idx] += compute_projection_sum(q_weight, q_scale_b, idx); - #else if (b_global < uniforms.Nq) { #if !single_scale_weights let q_scale_b = q_output_element_t(q_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); @@ -177,17 +153,6 @@ fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy let q_weight = q_b.getByOffset(b_global * uniforms.K_of_b + k_offset); q_inter_results[local_row_offset + idy][idx] += compute_projection_sum(q_weight, q_scale_b, idx); } - #endif - #if has_full_kv_tiles - #if !single_scale_weights - let k_scale_b = q_output_element_t(k_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); - let v_scale_b = q_output_element_t(v_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); - #endif - let k_weight = k_b.getByOffset(b_global * uniforms.K_of_b + k_offset); - let v_weight = v_b.getByOffset(b_global * uniforms.K_of_b + k_offset); - k_inter_results[local_row_offset + idy][idx] += compute_projection_sum(k_weight, k_scale_b, idx); - v_inter_results[local_row_offset + idy][idx] += compute_projection_sum(v_weight, v_scale_b, idx); - #else if (b_global < uniforms.Nkv) { #if !single_scale_weights let k_scale_b = q_output_element_t(k_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); @@ -198,7 +163,6 @@ fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy k_inter_results[local_row_offset + idy][idx] += compute_projection_sum(k_weight, k_scale_b, idx); v_inter_results[local_row_offset + idy][idx] += compute_projection_sum(v_weight, v_scale_b, idx); } - #endif } } workgroupBarrier(); @@ -291,18 +255,10 @@ $MAIN { k_output_value += k_inter_results[local_idx][b]; v_output_value += v_inter_results[local_idx][b]; } -#if has_full_q_tiles - { -#else if (b_global < uniforms.Nq) { -#endif q_output.setByOffset(batch * uniforms.Nq + b_global, q_output_value_t(q_output_value)); } -#if has_full_kv_tiles - { -#else if (b_global < uniforms.Nkv) { -#endif k_output.setByOffset(batch * uniforms.Nkv + b_global, k_output_value_t(k_output_value)); v_output.setByOffset(batch * uniforms.Nkv + b_global, v_output_value_t(v_output_value)); } diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 05d2a4b4d6a37..fa464722ea770 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -22,8 +22,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gr // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBitsQKVSimplifiedLayerNorm); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBitsSiluMul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBitsQkv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBitsMlp); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding); @@ -52,8 +52,8 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index e6a9ed43460c8..4cb5b19f54018 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3616,36 +3616,44 @@ For example, for 4 bits, the first 4 bits are stored in the lower 4 bits of a by } }); - static const char* MatMulNBitsSiluMul_ver1_doc = R"DOC( -MatMulNBitsSiluMul fuses two MatMulNBits projections that share the same input and computes + static const char* MatMulNBitsMlp_ver1_doc = R"DOC( +MatMulNBitsMlp fuses two MatMulNBits projections that share the same input and computes - Y = SiLU(MatMulNBits(A, gate_weight) + gate_bias) * (MatMulNBits(A, up_weight) + up_bias) - -where SiLU(x) = x * sigmoid(x). + gate = MatMulNBits(A, gate_weight) + gate_bias + up = MatMulNBits(A, up_weight) + up_bias + Y = activation(gate) * up It can also optionally fuse SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization before the two projections: A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon) - Y = SiLU(MatMulNBits(A_norm, gate_weight) + gate_bias) * (MatMulNBits(A_norm, up_weight) + up_bias) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up A_norm = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon) - Y = SiLU(MatMulNBits(A_norm, gate_weight) + gate_bias) * (MatMulNBits(A_norm, up_weight) + up_bias) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up This operator is intended for decoder MLP patterns such as Qwen-style gate and up projections, but it remains semantically valid for both prefill and decode because the output shape is the standard MatMul result shape derived from the runtime shape of A and the shared attributes K and N. +The operator contract includes a string attribute describing the fused gate activation. + When fused from SkipSimplifiedLayerNormalization, the optional residual-sum output may also be materialized: A_norm, input_skip_bias_sum = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon) - Y = SiLU(MatMulNBits(A_norm, gate_weight) + gate_bias) * (MatMulNBits(A_norm, up_weight) + up_bias) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up )DOC"; - ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBitsSiluMul) + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBitsMlp) .SetDomain(kMSDomain) .SinceVersion(1) - .SetDoc(MatMulNBitsSiluMul_ver1_doc) + .SetDoc(MatMulNBitsMlp_ver1_doc) .Attr("K", "Input feature dimension shared by both quantized weight matrices.", AttributeProto::INT) .Attr("N", "Output feature dimension shared by both quantized weight matrices.", AttributeProto::INT) .Attr("bits", "Bit-width used to quantize both weight matrices (valid range: 2~8)", AttributeProto::INT, static_cast(4)) @@ -3655,6 +3663,9 @@ When fused from SkipSimplifiedLayerNormalization, the optional residual-sum outp .Attr("accuracy_level", "The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.", AttributeProto::INT, static_cast(0)) + .Attr("activation", + "Activation applied to the gate projection.", + AttributeProto::STRING) .Input(0, "A", "The shared input tensor.", "T1") .Input(1, "skip", "Optional skip input used by SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) .Input(2, "norm_scale", "Optional RMSNorm scale with shape [K] used by SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) @@ -3664,7 +3675,7 @@ When fused from SkipSimplifiedLayerNormalization, the optional residual-sum outp .Input(6, "up_B", "Packed uint8 tensor for the up projection weights.", "T2") .Input(7, "up_scales", "Per-block scaling factors for the up projection.", "T1") .Input(8, "up_bias", "Optional bias for the up projection with shape [N].", "T1", OpSchema::Optional) - .Output(0, "Y", "The fused SiLU-multiply output tensor.", "T1") + .Output(0, "Y", "The fused gated MLP output tensor.", "T1") .Output(1, "input_skip_bias_sum", "Optional residual-sum output for SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") @@ -3727,8 +3738,8 @@ When fused from SkipSimplifiedLayerNormalization, the optional residual-sum outp } }); - static const char* MatMulNBitsQKVSimplifiedLayerNorm_ver1_doc = R"DOC( -MatMulNBitsQKVSimplifiedLayerNorm fuses either SimplifiedLayerNormalization (RMSNorm) + static const char* MatMulNBitsQkv_ver1_doc = R"DOC( +MatMulNBitsQkv fuses either SimplifiedLayerNormalization (RMSNorm) or SkipSimplifiedLayerNormalization with three MatMulNBits projections that share the same normalized activation. @@ -3743,10 +3754,10 @@ and may also return the input+skip residual sum as output 3. This operator is intended as a decode-oriented QKV fusion primitive. )DOC"; - ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBitsQKVSimplifiedLayerNorm) + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBitsQkv) .SetDomain(kMSDomain) .SinceVersion(1) - .SetDoc(MatMulNBitsQKVSimplifiedLayerNorm_ver1_doc) + .SetDoc(MatMulNBitsQkv_ver1_doc) .Attr("K", "Input feature dimension shared by the normalized input and all projection weights.", AttributeProto::INT) .Attr("Nq", "Output feature dimension of the Q projection.", AttributeProto::INT) .Attr("Nkv", "Output feature dimension shared by the K and V projections.", AttributeProto::INT) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 741781e2d0f18..cd1bcd15cccc2 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -56,8 +56,8 @@ #include "core/optimizer/layer_norm_fusion.h" #include "core/optimizer/matmul_activation_fusion.h" #include "core/optimizer/matmul_add_fusion.h" -#include "core/optimizer/matmul_nbits_qkv_sln_fusion.h" -#include "core/optimizer/matmul_nbits_silu_fusion.h" +#include "core/optimizer/matmul_nbits_qkv_fusion.h" +#include "core/optimizer/matmul_nbits_mlp_fusion.h" #include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" @@ -104,22 +104,6 @@ namespace onnxruntime::optimizer_utils { namespace { -constexpr const char* kOrtEnableMatMulNBitsSiluFusionEnvVar = "ORT_ENABLE_MATMUL_NBITS_SILU_FUSION"; -constexpr const char* kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar = - "ORT_ENABLE_MATMUL_NBITS_QKV_SIMPLIFIED_LAYER_NORM_FUSION"; - -#if !defined(ORT_MINIMAL_BUILD) -bool IsMatMulNBitsSiluFusionEnabled() { - return ParseEnvironmentVariableWithDefault(kOrtEnableMatMulNBitsSiluFusionEnvVar, 0) == 1; - //return true; -} - -bool IsMatMulNBitsQKVSimplifiedLayerNormFusionEnabled() { - return ParseEnvironmentVariableWithDefault(kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar, 0) == 1; - //return true; -} -#endif - } // namespace static void FilterTransformers(InlinedVector>& transformers, @@ -459,14 +443,10 @@ InlinedVector> GenerateTransformers( #endif transformers.emplace_back(std::make_unique(cpu_ep)); - if (IsMatMulNBitsSiluFusionEnabled()) { - transformers.emplace_back(std::make_unique( - InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); - } - if (IsMatMulNBitsQKVSimplifiedLayerNormFusionEnabled()) { - transformers.emplace_back(std::make_unique( - InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); - } + transformers.emplace_back(std::make_unique( + InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); + transformers.emplace_back(std::make_unique( + InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); #endif // !defined(DISABLE_CONTRIB_OPS) // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their diff --git a/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc similarity index 79% rename from onnxruntime/core/optimizer/matmul_nbits_silu_fusion.cc rename to onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc index 28b96a2bae7b6..ee4f1bc63aa3a 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc @@ -1,10 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/matmul_nbits_silu_fusion.h" +#include "core/optimizer/matmul_nbits_mlp_fusion.h" #include -#include #include "core/graph/graph_utils.h" #include "core/graph/node_attr_utils.h" @@ -14,6 +13,11 @@ namespace onnxruntime { namespace { +constexpr const char* kActivationAttrName = "activation"; +// The transformer name is generic for future expansion, but the current fused +// pattern and emitted op only support gate activation = "silu". +constexpr const char* kSupportedActivation = "silu"; + bool HasInput(const Node& node, size_t index) { return index < node.InputDefs().size() && node.InputDefs()[index] != nullptr && !node.InputDefs()[index]->Name().empty(); } @@ -39,7 +43,7 @@ bool IsSupportedSkipSimplifiedLayerNormalization(const Node& node) { return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SkipSimplifiedLayerNormalization", {1}, kMSDomain); } -bool IsSupportedSiluNormAnchor(const Node& node) { +bool IsSupportedMlpNormAnchor(const Node& node) { return IsSupportedSimplifiedLayerNormalization(node) || IsSupportedSkipSimplifiedLayerNormalization(node); } @@ -63,7 +67,6 @@ bool HasExpectedNormConsumers(const Graph& graph, const Node& node) { return false; } - // Match optimizer_utils::CheckOutputEdges safety check while allowing output 3 to be a graph output. for (auto output_edge_it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); output_edge_it != end; ++output_edge_it) { const auto& output_node = output_edge_it->GetNode(); const auto output_node_input_arg_idx = static_cast(output_edge_it->GetDstArgIndex()); @@ -95,9 +98,9 @@ bool HasSingleNonGraphConsumer(const Graph& graph, const Node& node) { return !graph.NodeProducesGraphOutput(node) && optimizer_utils::CheckOutputEdges(graph, node, 1); } -const Node* GetOptionalNormProducer(const Graph& graph, - const Node& gate_matmul, - const Node& up_matmul) { +const Node* GetNormProducer(const Graph& graph, + const Node& gate_matmul, + const Node& up_matmul) { if (gate_matmul.InputDefs().empty() || up_matmul.InputDefs().empty() || gate_matmul.InputDefs()[0] != up_matmul.InputDefs()[0]) { return nullptr; @@ -105,7 +108,7 @@ const Node* GetOptionalNormProducer(const Graph& graph, const Node* gate_input = GetInputNode(graph, gate_matmul, 0); const Node* up_input = GetInputNode(graph, up_matmul, 0); - if (gate_input == nullptr || gate_input != up_input || !IsSupportedSiluNormAnchor(*gate_input)) { + if (gate_input == nullptr || gate_input != up_input || !IsSupportedMlpNormAnchor(*gate_input)) { return nullptr; } @@ -189,8 +192,8 @@ bool IsFuseCandidate(const Graph& graph, } // namespace -Status MatMulNBitsSiluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { +Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -252,33 +255,30 @@ Status MatMulNBitsSiluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ continue; } - LOGS(logger, INFO) << "MatMulNBitsSiluFusion: matched candidate final_mul='" << node.Name() - << "' gate='" << gate_matmul->Name() << "' up='" << up_matmul->Name() - << "' sigmoid='" << sigmoid->Name() << "' silu_mul='" << silu_mul->Name() - << "' attrs={K=" << GetIntAttr(*gate_matmul, "K", -1, true) - << ", N=" << GetIntAttr(*gate_matmul, "N", -1, true) - << ", bits=" << GetIntAttr(*gate_matmul, "bits", 4) - << ", block_size=" << GetIntAttr(*gate_matmul, "block_size", -1, true) - << ", accuracy_level=" << GetIntAttr(*gate_matmul, "accuracy_level", 0) - << "}"; - - LOGS(logger, INFO) << "MatMulNBitsSiluFusion: EP state final_mul='" << node.GetExecutionProviderType() - << "' gate='" << gate_matmul->GetExecutionProviderType() - << "' up='" << up_matmul->GetExecutionProviderType() - << "' sigmoid='" << sigmoid->GetExecutionProviderType() - << "' silu_mul='" << silu_mul->GetExecutionProviderType() << "'"; + LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: matched candidate output_mul='" << node.Name() + << "' gate='" << gate_matmul->Name() << "' up='" << up_matmul->Name() + << "' sigmoid='" << sigmoid->Name() << "' activation_mul='" << silu_mul->Name() + << "' attrs={K=" << GetIntAttr(*gate_matmul, "K", -1, true) + << ", N=" << GetIntAttr(*gate_matmul, "N", -1, true) + << ", bits=" << GetIntAttr(*gate_matmul, "bits", 4) + << ", block_size=" << GetIntAttr(*gate_matmul, "block_size", -1, true) + << ", accuracy_level=" << GetIntAttr(*gate_matmul, "accuracy_level", 0) + << "}"; if ((!gate_matmul->GetExecutionProviderType().empty() && gate_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) || (!up_matmul->GetExecutionProviderType().empty() && up_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) || (!sigmoid->GetExecutionProviderType().empty() && sigmoid->GetExecutionProviderType() != kWebGpuExecutionProvider) || (!silu_mul->GetExecutionProviderType().empty() && silu_mul->GetExecutionProviderType() != kWebGpuExecutionProvider)) { - LOGS(logger, INFO) << "MatMulNBitsSiluFusion: skipping candidate due to non-WebGPU EP assignment."; + LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: skipping candidate due to non-WebGPU EP assignment."; + continue; + } + + const Node* norm = GetNormProducer(graph, *gate_matmul, *up_matmul); + if (norm == nullptr) { continue; } - const Node* norm = GetOptionalNormProducer(graph, *gate_matmul, *up_matmul); - if (norm != nullptr && - !norm->GetExecutionProviderType().empty() && norm->GetExecutionProviderType() != kWebGpuExecutionProvider) { + if (!norm->GetExecutionProviderType().empty() && norm->GetExecutionProviderType() != kWebGpuExecutionProvider) { continue; } @@ -288,14 +288,15 @@ Status MatMulNBitsSiluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ utils::SetNodeAttribute(utils::MakeAttribute("bits", GetIntAttr(*gate_matmul, "bits", 4)), attrs); utils::SetNodeAttribute(utils::MakeAttribute("block_size", GetIntAttr(*gate_matmul, "block_size", -1, true)), attrs); utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", GetIntAttr(*gate_matmul, "accuracy_level", 0)), attrs); + utils::SetNodeAttribute(utils::MakeAttribute(kActivationAttrName, std::string{kSupportedActivation}), attrs); NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); const bool is_skip_sln = norm != nullptr && IsSupportedSkipSimplifiedLayerNormalization(*norm); InlinedVector fused_inputs{ - const_cast(norm != nullptr ? norm->InputDefs()[0] : gate_matmul->InputDefs()[0]), - is_skip_sln ? const_cast(norm->InputDefs()[1]) : &empty_arg, - norm != nullptr ? const_cast(norm->InputDefs()[is_skip_sln ? 2 : 1]) : &empty_arg, + const_cast(norm->InputDefs()[0]), + is_skip_sln ? const_cast(norm->InputDefs()[1]) : &empty_arg, + const_cast(norm->InputDefs()[is_skip_sln ? 2 : 1]), const_cast(gate_matmul->InputDefs()[1]), const_cast(gate_matmul->InputDefs()[2]), HasInput(*gate_matmul, 5) ? const_cast(gate_matmul->InputDefs()[5]) : &empty_arg, @@ -310,18 +311,15 @@ Status MatMulNBitsSiluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ fused_outputs.push_back(const_cast(norm->OutputDefs()[3])); } - const auto norm_input_edges = norm != nullptr ? graph_utils::GraphEdge::GetNodeInputEdges(*norm) - : std::vector{}; + const auto norm_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*norm); const auto gate_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*gate_matmul); const auto up_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*up_matmul); const auto final_mul_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node); const auto norm_output_edges = preserve_skip_output ? graph_utils::GraphEdge::GetNodeOutputEdges(*norm) : std::vector{}; - if (norm != nullptr) { - graph_utils::RemoveNodeOutputEdges(graph, const_cast(*norm)); - graph.RemoveNode(norm->Index()); - } + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*norm)); + graph.RemoveNode(norm->Index()); graph_utils::RemoveNodeOutputEdges(graph, const_cast(*gate_matmul)); graph.RemoveNode(gate_matmul->Index()); graph_utils::RemoveNodeOutputEdges(graph, const_cast(*up_matmul)); @@ -333,33 +331,25 @@ Status MatMulNBitsSiluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ graph_utils::RemoveNodeOutputEdges(graph, node); graph.RemoveNode(node.Index()); - Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsSiluMul"), - "MatMulNBitsSiluMul", - "fused MatMulNBits gate/up projections with SiLU multiply", + Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsMlp"), + "MatMulNBitsMlp", + "fused MatMulNBits gated MLP projections", fused_inputs, fused_outputs, &attrs, kMSDomain); fused_node.SetExecutionProviderType(kWebGpuExecutionProvider); - LOGS(logger, INFO) << "MatMulNBitsSiluFusion: created fused node '" << fused_node.Name() - << "' from final_mul='" << node.Name() << "'"; + LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: created fused node '" << fused_node.Name() + << "' from output_mul='" << node.Name() << "'"; - if (norm != nullptr) { - for (const auto& input_edge : norm_input_edges) { - int fused_input_index = input_edge.dst_arg_index; - if (!is_skip_sln && input_edge.dst_arg_index == 1) { - fused_input_index = 2; - } - - graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index); - } - } else { - for (const auto& input_edge : gate_input_edges) { - if (input_edge.dst_arg_index == 0) { - graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, 0); - } + for (const auto& input_edge : norm_input_edges) { + int fused_input_index = input_edge.dst_arg_index; + if (!is_skip_sln && input_edge.dst_arg_index == 1) { + fused_input_index = 2; } + + graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index); } auto add_input_edge_if_present = [&](const std::vector& edges, @@ -396,4 +386,4 @@ Status MatMulNBitsSiluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ return Status::OK(); } -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.h b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h similarity index 57% rename from onnxruntime/core/optimizer/matmul_nbits_silu_fusion.h rename to onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h index d2c84dc2a3983..2208df0f3d3e4 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_silu_fusion.h +++ b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h @@ -7,10 +7,10 @@ namespace onnxruntime { -class MatMulNBitsSiluFusion : public GraphTransformer { +class MatMulNBitsMlpFusion : public GraphTransformer { public: - explicit MatMulNBitsSiluFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("MatMulNBitsSiluFusion", compatible_execution_providers) {} + explicit MatMulNBitsMlpFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("MatMulNBitsMlpFusion", compatible_execution_providers) {} private: Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc similarity index 90% rename from onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.cc rename to onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc index 1e8bc2346c8d7..1eeb38058f34c 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/matmul_nbits_qkv_sln_fusion.h" +#include "core/optimizer/matmul_nbits_qkv_fusion.h" #include #include @@ -156,8 +156,8 @@ bool IsFuseCandidate(const Node& norm, const QkvNodes& qkv) { } // namespace -Status MatMulNBitsQKVSimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { +Status MatMulNBitsQkvFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -189,12 +189,12 @@ Status MatMulNBitsQKVSimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& mo const bool is_skip_sln = IsSupportedSkipSimplifiedLayerNormalization(node); - LOGS(logger, INFO) << "MatMulNBitsQKVSimplifiedLayerNormFusion: matched norm='" << node.Name() - << "' q='" << qkv_nodes->q->Name() << "' k='" << qkv_nodes->k->Name() - << "' v='" << qkv_nodes->v->Name() << "' attrs={K=" << K - << ", Nq=" << Nq << ", Nkv=" << Nkv << ", bits=" << bits - << ", block_size=" << block_size << ", accuracy_level=" << accuracy_level - << ", epsilon=" << epsilon << ", skip_sln=" << is_skip_sln << "}"; + LOGS(logger, VERBOSE) << "MatMulNBitsQkvFusion: matched norm='" << node.Name() + << "' q='" << qkv_nodes->q->Name() << "' k='" << qkv_nodes->k->Name() + << "' v='" << qkv_nodes->v->Name() << "' attrs={K=" << K + << ", Nq=" << Nq << ", Nkv=" << Nkv << ", bits=" << bits + << ", block_size=" << block_size << ", accuracy_level=" << accuracy_level + << ", epsilon=" << epsilon << ", skip_sln=" << is_skip_sln << "}"; NodeAttributes attrs; utils::SetNodeAttribute(utils::MakeAttribute("K", K), attrs); @@ -209,7 +209,7 @@ Status MatMulNBitsQKVSimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& mo InlinedVector fused_inputs{ const_cast(node.InputDefs()[0]), - is_skip_sln ? const_cast(node.InputDefs()[1]) : &empty_arg, + is_skip_sln ? const_cast(node.InputDefs()[1]) : &empty_arg, const_cast(node.InputDefs()[is_skip_sln ? 2 : 1]), const_cast(qkv_nodes->q->InputDefs()[1]), const_cast(qkv_nodes->q->InputDefs()[2]), @@ -235,7 +235,6 @@ Status MatMulNBitsQKVSimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& mo const auto norm_output_edges = is_skip_sln && HasProducedOutput(node, 3) ? graph_utils::GraphEdge::GetNodeOutputEdges(node) : std::vector{}; - graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->q)); graph.RemoveNode(qkv_nodes->q->Index()); graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->k)); @@ -245,8 +244,8 @@ Status MatMulNBitsQKVSimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& mo graph_utils::RemoveNodeOutputEdges(graph, node); graph.RemoveNode(node.Index()); - Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsQKVSimplifiedLayerNorm"), - "MatMulNBitsQKVSimplifiedLayerNorm", + Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsQkv"), + "MatMulNBitsQkv", "fused SimplifiedLayerNormalization with Q/K/V MatMulNBits projections", fused_inputs, fused_outputs, @@ -254,6 +253,10 @@ Status MatMulNBitsQKVSimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& mo kMSDomain); fused_node.SetExecutionProviderType(kWebGpuExecutionProvider); + LOGS(logger, VERBOSE) << "MatMulNBitsQkvFusion: created fused node '" << fused_node.Name() + << "' from norm='" << node.Name() << "' q='" << qkv_nodes->q->Name() + << "' k='" << qkv_nodes->k->Name() << "' v='" << qkv_nodes->v->Name() << "'"; + for (const auto& input_edge : norm_input_edges) { int fused_input_index = input_edge.dst_arg_index; if (!is_skip_sln && input_edge.dst_arg_index == 1) { diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.h b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h similarity index 65% rename from onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.h rename to onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h index a40cc98459818..2e028c28190b4 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_qkv_sln_fusion.h +++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h @@ -7,11 +7,11 @@ namespace onnxruntime { -class MatMulNBitsQKVSimplifiedLayerNormFusion : public GraphTransformer { +class MatMulNBitsQkvFusion : public GraphTransformer { public: - explicit MatMulNBitsQKVSimplifiedLayerNormFusion( + explicit MatMulNBitsQkvFusion( const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("MatMulNBitsQKVSimplifiedLayerNormFusion", compatible_execution_providers) {} + : GraphTransformer("MatMulNBitsQkvFusion", compatible_execution_providers) {} private: Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index 3e1b87821fe2f..7565cc8d52a87 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -9,14 +9,18 @@ namespace onnxruntime { namespace webgpu { GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator) + : GpuBufferAllocator([buffer_manager_ptr = &buffer_manager]() -> const BufferManager& { return *buffer_manager_ptr; }, is_read_only_allocator) { +} + +GpuBufferAllocator::GpuBufferAllocator(std::function buffer_manager_getter, bool is_read_only_allocator) : IAllocator( OrtMemoryInfo(WEBGPU_BUFFER, is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator : OrtAllocatorType::OrtDeviceAllocator, WebGpuDevice, OrtMemTypeDefault)), - buffer_manager_{buffer_manager}, - mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} { + buffer_manager_getter_{std::move(buffer_manager_getter)}, + mapped_at_creation_{is_read_only_allocator && buffer_manager_getter_().SupportsUMA()} { } void* GpuBufferAllocator::Alloc(size_t size) { @@ -26,15 +30,17 @@ void* GpuBufferAllocator::Alloc(size_t size) { stats_.num_allocs++; + const auto& buffer_manager = buffer_manager_getter_(); + wgpu::BufferUsage usage = mapped_at_creation_ ? wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite : wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect; - return buffer_manager_.Create(size, usage); + return buffer_manager.Create(size, usage); } void GpuBufferAllocator::Free(void* p) { if (p != nullptr) { - buffer_manager_.Release(static_cast(p)); + buffer_manager_getter_().Release(static_cast(p)); stats_.num_allocs--; } } diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index 74b3d669fcf3b..fadfc8c86cfc4 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -3,6 +3,8 @@ #pragma once +#include + #include "core/framework/allocator.h" #include "core/framework/ortdevice.h" @@ -19,6 +21,7 @@ inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU, class GpuBufferAllocator : public IAllocator { public: GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator); + GpuBufferAllocator(std::function buffer_manager_getter, bool is_read_only_allocator); virtual void* Alloc(size_t size) override; virtual void Free(void* p) override; @@ -26,7 +29,7 @@ class GpuBufferAllocator : public IAllocator { private: AllocatorStats stats_; - const BufferManager& buffer_manager_; + std::function buffer_manager_getter_; bool mapped_at_creation_; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 86c40c3b93750..fb6da131e45fe 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -165,7 +165,7 @@ class WebGpuContextFactory { class WebGpuContext final { public: Status Wait(wgpu::Future f); - Status WaitForQueueIdle(); + Status WaitForQueueIdle(); const wgpu::Device& Device() const { return device_; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 81e4a22892254..dba8e2bb4da7f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -940,16 +940,6 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, enable_int64_{config.enable_graph_capture || config.enable_int64}, multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset}, prepack_allocator_{std::make_shared(context_.InitializerBufferManager(), false)} { - // If graph capture is enabled, create a dedicated buffer manager for graph mode - if (enable_graph_capture_) { - // Create buffer manager for graph capture mode with appropriate cache modes - graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create( - context_, - webgpu::BufferCacheMode::Graph, - webgpu::BufferCacheMode::GraphSimple, - webgpu::BufferCacheMode::Disabled); - } - if (config.enable_pix_capture) { #if defined(ENABLE_PIX_FOR_WEBGPU_EP) // set pix frame generator @@ -965,7 +955,7 @@ std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { // allocator for initializers std::make_unique(context_.InitializerBufferManager(), true), // default allocator - std::make_unique(BufferManager(), false), + std::make_unique([this]() -> const webgpu::BufferManager& { return BufferManager(); }, false), }; } @@ -1130,6 +1120,14 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op } if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) { + if (!graph_buffer_mgr_) { + graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create( + context_, + webgpu::BufferCacheMode::Graph, + webgpu::BufferCacheMode::GraphSimple, + webgpu::BufferCacheMode::Disabled); + } + graph_buffer_mgr_active_ = true; context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); } m_current_graph_annotation_id = graph_annotation_id; @@ -1151,6 +1149,8 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti } } + graph_buffer_mgr_active_ = false; + if (session_profiler_ && session_profiler_->Enabled()) { // Session-level profiling: collect into profiler's own events storage. context_.CollectProfilingData(session_profiler_->GpuEvents()); @@ -1182,6 +1182,7 @@ bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) { ORT_ENFORCE(IsGraphCaptured(graph_annotation_id)); + ORT_ENFORCE(graph_buffer_mgr_ != nullptr, "Graph buffer manager must exist before replay."); // TODO: enable profiling in run level if (session_profiler_ && session_profiler_->Enabled()) { context_.StartProfiling(); @@ -1195,7 +1196,7 @@ Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) { } webgpu::BufferManager& WebGpuExecutionProvider::BufferManager() const { - if (graph_buffer_mgr_) { + if (graph_buffer_mgr_active_ && graph_buffer_mgr_) { return *graph_buffer_mgr_; } else { return context_.BufferManager(); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index b46d3f3cb45d2..4a68963092071 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -124,6 +124,7 @@ class WebGpuExecutionProvider : public IExecutionProvider { bool enable_int64_ = false; uint32_t multi_rotary_cache_concat_offset_ = 0; bool is_graph_captured_ = false; + bool graph_buffer_mgr_active_ = false; int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. int m_current_graph_annotation_id = 0; diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc index 7bc5c19c191c5..a95faf04f37d3 100644 --- a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc +++ b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #include "core/providers/webgpu/webgpu_provider_options.h" @@ -33,8 +34,9 @@ extern const OrtApi* g_ort; namespace { constexpr const char* kMatMulNBitsAutoTunerEnvVar = "ORT_WEBGPU_MATMUL_NBITS_ENABLE_AUTO_TUNER"; constexpr const char* kDecodeBenchmarkModeEnvVar = "ORT_WEBGPU_MATMUL_NBITS_BENCHMARK_MODE"; -constexpr const char* kDecodeBenchmarkGpuEnvVar = "ORT_WEBGPU_MATMUL_NBITS_BENCHMARK_GPU"; constexpr const char* kDecodeBenchmarkGraphCaptureEnvVar = "ORT_WEBGPU_MATMUL_NBITS_ENABLE_GRAPH_CAPTURE"; +constexpr const char* kDecodeBenchmarkOptimizedModelPathEnvVar = "ORT_WEBGPU_MATMUL_NBITS_OPTIMIZED_MODEL_PATH"; +constexpr const char* kDecodeBenchmarkVerboseSessionLogEnvVar = "ORT_WEBGPU_MATMUL_NBITS_VERBOSE_SESSION_LOG"; constexpr float kDecodeCorrectnessAbsTolerance = 0.1f; constexpr float kDecodeCorrectnessRelTolerance = 0.01f; constexpr const char* kBenchmarkGraphCaptureAnnotationId = "1"; @@ -44,32 +46,20 @@ enum class DecodeBenchmarkMode { kCorrectness, }; -enum class DecodeBenchmarkGpu { - kRtx5060Ti, - kT1000, -}; - bool IsMatMulNBitsAutoTunerEnabled(); bool IsGraphCaptureBenchmarkEnabled(); - -struct DecodeBenchConfig { - int64_t n; - int64_t k; - int64_t bits; - int64_t block_size; - int64_t accuracy_level; -}; +bool IsVerboseSessionLogEnabled(); +std::string GetOptimizedModelPath(); enum class MlpDecodeBenchmarkVariant { kUnfused, kFused, - kSkipNormThenFused, - kSkipNormPassthroughThenFused, }; enum class MlpNormKind { - kNone, + kSimplified, kSkipSimplified, + kSkipSimplifiedPassthrough, }; struct MlpDecodeBenchConfig { @@ -81,18 +71,11 @@ struct MlpDecodeBenchConfig { }; struct AdapterSelectionConfig { - // adapter_type: Dawn adapter type to select, e.g. integrated or discrete GPU. - // preferred_vendor_id/device_id: stable PCI identifiers used to locate the target GPU regardless of enumeration order. - // preferred_device_substring: fallback name match if device IDs are unavailable or change. - // adapter_index: fallback zero-based index among only adapters of adapter_type if the preferred adapter is not found. + // preferred_device_substring: optional case-insensitive device-name hint. // context_id: ORT WebGPU custom context ID used to bind the externally created instance/device. // backend_type: Dawn backend to enumerate adapters from, e.g. D3D12 or Vulkan. // print_adapter_list: whether to print all discovered adapters before selecting one. - WGPUAdapterType adapter_type; - uint32_t preferred_vendor_id; - uint32_t preferred_device_id; const char* preferred_device_substring; - int adapter_index; int context_id; WGPUBackendType backend_type; bool print_adapter_list; @@ -119,18 +102,8 @@ struct SelectedWebGpuContext { std::string selected_adapter_summary; }; -struct DecodeTrafficStats { - double input_bytes; - double packed_weight_bytes; - double scale_bytes; - double output_bytes; - double total_bytes; -}; - struct MlpTrafficStats { double input_bytes; - double skip_input_bytes; - double norm_scale_bytes; double packed_weight_bytes; double scale_bytes; double intermediate_bytes; @@ -155,6 +128,7 @@ enum class QkvDecodeBenchmarkVariant { enum class QkvNormKind { kSimplified, kSkipSimplified, + kSkipSimplifiedPassthrough, }; struct QkvTrafficStats { @@ -168,7 +142,7 @@ struct QkvTrafficStats { double total_bytes; }; -constexpr double kRtx5060TiTheoreticalBandwidthBytesPerSecond = 448.0 * 1000.0 * 1000.0 * 1000.0; +constexpr double kRtxTheoreticalBandwidthBytesPerSecond = 448.0 * 1000.0 * 1000.0 * 1000.0; constexpr int kDecodeWarmupRuns = 25; DecodeBenchmarkMode GetDecodeBenchmarkMode() { @@ -191,29 +165,17 @@ bool IsDecodeBenchmarkPerfMode() { return GetDecodeBenchmarkMode() == DecodeBenchmarkMode::kPerf; } -DecodeBenchmarkGpu GetDecodeBenchmarkGpu() { - std::string gpu_env = onnxruntime::Env::Default().GetEnvironmentVar(kDecodeBenchmarkGpuEnvVar); - if (gpu_env.empty()) { - return DecodeBenchmarkGpu::kRtx5060Ti; - } - - std::transform(gpu_env.begin(), gpu_env.end(), gpu_env.begin(), - [](unsigned char value) { return static_cast(std::tolower(value)); }); - if (gpu_env == "t" || gpu_env == "t1000") { - return DecodeBenchmarkGpu::kT1000; - } - - return DecodeBenchmarkGpu::kRtx5060Ti; -} - -std::string GetDecodeBenchmarkLabel() { +std::string GetDecodeBenchmarkLabel(const char* shape_label = nullptr) { const char* mode_label = IsDecodeBenchmarkPerfMode() ? "perf" : "correctness"; - const char* adapter_label = GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kRtx5060Ti ? "rtx" : "t"; const char* tuner_label = IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off"; const char* graph_label = IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"; std::ostringstream stream; - stream << "fp16_decode_" << mode_label << '_' << adapter_label << '_' << tuner_label << '_' << graph_label; + stream << "fp16_decode"; + if (shape_label != nullptr && shape_label[0] != '\0') { + stream << '_' << shape_label; + } + stream << '_' << mode_label << "_auto_gpu_" << tuner_label << '_' << graph_label; return stream.str(); } @@ -239,6 +201,21 @@ bool IsGraphCaptureBenchmarkEnabled() { return graph_capture_env != "0" && graph_capture_env != "false" && graph_capture_env != "off"; } +bool IsVerboseSessionLogEnabled() { + std::string verbose_log_env = onnxruntime::Env::Default().GetEnvironmentVar(kDecodeBenchmarkVerboseSessionLogEnvVar); + if (verbose_log_env.empty()) { + return false; + } + + std::transform(verbose_log_env.begin(), verbose_log_env.end(), verbose_log_env.begin(), + [](unsigned char value) { return static_cast(std::tolower(value)); }); + return verbose_log_env != "0" && verbose_log_env != "false" && verbose_log_env != "off"; +} + +std::string GetOptimizedModelPath() { + return onnxruntime::Env::Default().GetEnvironmentVar(kDecodeBenchmarkOptimizedModelPathEnvVar); +} + Ort::RunOptions CreateBenchmarkRunOptions() { Ort::RunOptions run_options; if (IsGraphCaptureBenchmarkEnabled()) { @@ -341,123 +318,86 @@ std::string FormatFeatureSupport(const dawn::native::Adapter& adapter) { return stream.str(); } -DecodeTrafficStats CalculateDecodeTrafficStats(const DecodeBenchConfig& config) { - const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; - const int64_t blob_size = (config.block_size * config.bits) / 8; - - const double input_bytes = static_cast(config.k) * sizeof(Ort::Float16_t); - const double packed_weight_bytes = static_cast(config.n) * static_cast(k_blocks) * static_cast(blob_size); - const double scale_bytes = static_cast(config.n) * static_cast(k_blocks) * sizeof(Ort::Float16_t); - const double output_bytes = static_cast(config.n) * sizeof(Ort::Float16_t); - - return { - input_bytes, - packed_weight_bytes, - scale_bytes, - output_bytes, - input_bytes + packed_weight_bytes + scale_bytes + output_bytes, - }; +std::string ToLower(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char character) { return static_cast(std::tolower(character)); }); + return value; } MlpTrafficStats CalculateMlpTrafficStats(const MlpDecodeBenchConfig& config, - MlpDecodeBenchmarkVariant variant, - MlpNormKind norm_kind) { + MlpDecodeBenchmarkVariant variant, + MlpNormKind norm_kind) { const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; const int64_t blob_size = (config.block_size * config.bits) / 8; const bool is_unfused = variant == MlpDecodeBenchmarkVariant::kUnfused; - const bool is_skip_norm_then_fused = variant == MlpDecodeBenchmarkVariant::kSkipNormThenFused; - const bool is_skip_norm_passthrough_then_fused = variant == MlpDecodeBenchmarkVariant::kSkipNormPassthroughThenFused; - const bool has_skip_norm = norm_kind == MlpNormKind::kSkipSimplified; const double input_reads = variant == MlpDecodeBenchmarkVariant::kUnfused ? 2.0 : 1.0; - const double intermediate_bytes = - (is_unfused ? 4.0 * static_cast(config.n) * sizeof(Ort::Float16_t) : 0.0) + - ((is_unfused || is_skip_norm_then_fused) && has_skip_norm - ? static_cast(config.k) * sizeof(Ort::Float16_t) - : 0.0); + const bool has_skip = norm_kind == MlpNormKind::kSkipSimplified || + norm_kind == MlpNormKind::kSkipSimplifiedPassthrough; + const double skip_input_bytes = has_skip ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; + const double norm_scale_bytes = static_cast(config.k) * sizeof(Ort::Float16_t); + const double intermediate_bytes = is_unfused ? 4.0 * static_cast(config.n) * sizeof(Ort::Float16_t) : 0.0; const double input_bytes = input_reads * static_cast(config.k) * sizeof(Ort::Float16_t); - const double skip_input_bytes = - has_skip_norm ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; - const double norm_scale_bytes = - has_skip_norm ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; const double packed_weight_bytes = 2.0 * static_cast(config.n) * static_cast(k_blocks) * static_cast(blob_size); const double scale_bytes = 2.0 * static_cast(config.n) * static_cast(k_blocks) * sizeof(Ort::Float16_t); const double output_bytes = - (static_cast(config.n) + - (is_skip_norm_passthrough_then_fused && has_skip_norm ? static_cast(config.k) : 0.0)) * + static_cast(config.n + (norm_kind == MlpNormKind::kSkipSimplifiedPassthrough ? config.k : 0)) * sizeof(Ort::Float16_t); return { input_bytes, - skip_input_bytes, - norm_scale_bytes, packed_weight_bytes, scale_bytes, intermediate_bytes, output_bytes, - input_bytes + skip_input_bytes + norm_scale_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, + input_bytes + skip_input_bytes + norm_scale_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, }; } QkvTrafficStats CalculateQkvTrafficStats(const QkvDecodeBenchConfig& config, - QkvDecodeBenchmarkVariant variant, - QkvNormKind norm_kind) { + QkvDecodeBenchmarkVariant variant, + QkvNormKind norm_kind) { const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; const int64_t blob_size = (config.block_size * config.bits) / 8; + const bool has_skip = norm_kind == QkvNormKind::kSkipSimplified || + norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; + const bool has_skip_passthrough = norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; + const double input_bytes = static_cast(config.k) * sizeof(Ort::Float16_t); - const double skip_input_bytes = norm_kind == QkvNormKind::kSkipSimplified - ? static_cast(config.k) * sizeof(Ort::Float16_t) - : 0.0; + const double skip_input_bytes = has_skip ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; const double norm_scale_bytes = static_cast(config.k) * sizeof(Ort::Float16_t); const double packed_weight_bytes = static_cast(config.q_n + 2 * config.kv_n) * static_cast(k_blocks) * static_cast(blob_size); const double scale_bytes = static_cast(config.q_n + 2 * config.kv_n) * static_cast(k_blocks) * sizeof(Ort::Float16_t); const double intermediate_bytes = - variant == QkvDecodeBenchmarkVariant::kUnfused ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; + variant == QkvDecodeBenchmarkVariant::kUnfused ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; const double output_bytes = - static_cast(config.q_n + 2 * config.kv_n + (norm_kind == QkvNormKind::kSkipSimplified ? config.k : 0)) * - sizeof(Ort::Float16_t); + static_cast(config.q_n + 2 * config.kv_n + (has_skip_passthrough ? config.k : 0)) * + sizeof(Ort::Float16_t); return { input_bytes, - skip_input_bytes, + skip_input_bytes, norm_scale_bytes, packed_weight_bytes, scale_bytes, intermediate_bytes, output_bytes, - input_bytes + skip_input_bytes + norm_scale_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, + input_bytes + skip_input_bytes + norm_scale_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, }; } AdapterSelectionConfig GetAdapterSelectionConfig() { - if (GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kT1000) { - return { - WGPUAdapterType_DiscreteGPU, // adapter_type - 4318, // preferred_vendor_id (NVIDIA) - 8112, // preferred_device_id (T1000) - "T1000", // preferred_device_substring - 0, // adapter_index fallback - 0, // context_id - WGPUBackendType_D3D12, // backend_type - true, // print_adapter_list - }; - } - - // Prefer the RTX 5060 Ti by stable PCI identity so selection does not depend on - // Dawn enumeration order. Fall back to the historical second discrete adapter. + // Prefer a 5060 Ti when Dawn exposes one, otherwise fall back to the first + // Dawn-enumerated adapter so the benchmark remains robust across machines. return { - WGPUAdapterType_DiscreteGPU, // adapter_type - 4318, // preferred_vendor_id (NVIDIA) - 11524, // preferred_device_id (RTX 5060 Ti) - "RTX 5060 Ti", // preferred_device_substring - 1, // adapter_index fallback - 1, // context_id - WGPUBackendType_D3D12, // backend_type - true, // print_adapter_list + "5060 Ti", // preferred_device_substring + 1, // context_id + WGPUBackendType_D3D12, // backend_type + true, // print_adapter_list }; } @@ -552,44 +492,22 @@ SelectedWebGpuContext CreateSelectedWebGpuContext() { } AdapterCandidate* selected_adapter = nullptr; - for (auto& candidate : candidates) { - if (candidate.adapter_type == config.adapter_type && - candidate.vendor_id == config.preferred_vendor_id && - candidate.device_id == config.preferred_device_id) { - selected_adapter = &candidate; - break; - } - } - - if (selected_adapter == nullptr && config.preferred_device_substring != nullptr) { + if (config.preferred_device_substring != nullptr) { + const std::string preferred_substring = ToLower(config.preferred_device_substring); for (auto& candidate : candidates) { - if (candidate.adapter_type == config.adapter_type && - candidate.device.find(config.preferred_device_substring) != std::string::npos) { + if (ToLower(candidate.device).find(preferred_substring) != std::string::npos) { selected_adapter = &candidate; break; } } } - if (selected_adapter == nullptr) { - for (auto& candidate : candidates) { - if (candidate.adapter_type == config.adapter_type && - candidate.type_index == config.adapter_index) { - selected_adapter = &candidate; - break; - } - } + if (selected_adapter == nullptr && !candidates.empty()) { + selected_adapter = &candidates.front(); } if (selected_adapter == nullptr) { - std::ostringstream stream; - stream << "Failed to find preferred " << AdapterTypeToString(config.adapter_type) - << " adapter vendor_id=" << config.preferred_vendor_id - << " device_id=" << config.preferred_device_id - << " name~=" << (config.preferred_device_substring ? config.preferred_device_substring : "") - << ", or fallback adapter index " << config.adapter_index - << ". Update GetAdapterSelectionConfig() to match the available adapters listed above."; - throw std::runtime_error(stream.str()); + throw std::runtime_error("No Dawn adapter candidates were available for WebGPU benchmark selection."); } const wgpu::Adapter adapter = selected_adapter->adapter.Get(); @@ -657,24 +575,6 @@ void AddTensorValueInfo(ONNX_NAMESPACE::GraphProto& graph, } } -std::vector GetDecodeBenchConfigs() { - // Each entry is {N, K, bits, block_size, accuracy_level} for a decode-style M=1 run. - return { - // QKV + AttnProj - {1024, 2048, 4, 32, 4}, - {2048, 2048, 4, 32, 4}, - - // Gate + Up proj - {6144, 2048, 4, 32, 4}, - - // Down proj - {2048, 6144, 4, 32, 4}, - - // Vocab proj - {151936, 2048, 4, 32, 4}, - }; -} - std::vector GetMlpDecodeBenchConfigs() { // Qwen3-1.7B MLP gate/up decode geometry: hidden=2048, intermediate=6144. return { @@ -737,25 +637,25 @@ void AddMatMulNBitsNode(ONNX_NAMESPACE::GraphProto& graph, attr_accuracy->set_i(accuracy_level); } -void AddMatMulNBitsSiluMulNode(ONNX_NAMESPACE::GraphProto& graph, - const std::string& node_name, - const std::string& input_name, - const std::string& skip_input_name, - const std::string& norm_scale_name, - const std::string& gate_weight_name, - const std::string& gate_scale_name, - const std::string& up_weight_name, - const std::string& up_scale_name, - const std::string& output_name, - const std::string& skip_sum_output_name, - int64_t k, - int64_t n, - int64_t bits, - int64_t block_size, - int64_t accuracy_level) { +void AddMatMulNBitsMlpNode(ONNX_NAMESPACE::GraphProto& graph, + const std::string& node_name, + const std::string& input_name, + const std::string& skip_input_name, + const std::string& norm_scale_name, + const std::string& gate_weight_name, + const std::string& gate_scale_name, + const std::string& up_weight_name, + const std::string& up_scale_name, + const std::string& output_name, + const std::string& skip_sum_output_name, + int64_t k, + int64_t n, + int64_t bits, + int64_t block_size, + int64_t accuracy_level) { auto* node = graph.add_node(); node->set_name(node_name); - node->set_op_type("MatMulNBitsSiluMul"); + node->set_op_type("MatMulNBitsMlp"); node->set_domain("com.microsoft"); node->add_input(input_name); node->add_input(skip_input_name); @@ -795,33 +695,38 @@ void AddMatMulNBitsSiluMulNode(ONNX_NAMESPACE::GraphProto& graph, attr_accuracy->set_name("accuracy_level"); attr_accuracy->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); attr_accuracy->set_i(accuracy_level); + + auto* attr_activation = node->add_attribute(); + attr_activation->set_name("activation"); + attr_activation->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); + attr_activation->set_s("silu"); } -void AddMatMulNBitsQKVSimplifiedLayerNormNode(ONNX_NAMESPACE::GraphProto& graph, - const std::string& node_name, - const std::string& input_name, - const std::string& skip_input_name, - const std::string& norm_scale_name, - const std::string& q_weight_name, - const std::string& q_scale_name, - const std::string& k_weight_name, - const std::string& k_scale_name, - const std::string& v_weight_name, - const std::string& v_scale_name, - const std::string& q_output_name, - const std::string& k_output_name, - const std::string& v_output_name, - const std::string& skip_sum_output_name, - int64_t k, - int64_t q_n, - int64_t kv_n, - int64_t bits, - int64_t block_size, - int64_t accuracy_level, - float epsilon) { +void AddMatMulNBitsQkvNode(ONNX_NAMESPACE::GraphProto& graph, + const std::string& node_name, + const std::string& input_name, + const std::string& skip_input_name, + const std::string& norm_scale_name, + const std::string& q_weight_name, + const std::string& q_scale_name, + const std::string& k_weight_name, + const std::string& k_scale_name, + const std::string& v_weight_name, + const std::string& v_scale_name, + const std::string& q_output_name, + const std::string& k_output_name, + const std::string& v_output_name, + const std::string& skip_sum_output_name, + int64_t k, + int64_t q_n, + int64_t kv_n, + int64_t bits, + int64_t block_size, + int64_t accuracy_level, + float epsilon) { auto* node = graph.add_node(); node->set_name(node_name); - node->set_op_type("MatMulNBitsQKVSimplifiedLayerNorm"); + node->set_op_type("MatMulNBitsQkv"); node->set_domain("com.microsoft"); node->add_input(input_name); node->add_input(skip_input_name); @@ -869,83 +774,35 @@ void AddMatMulNBitsQKVSimplifiedLayerNormNode(ONNX_NAMESPACE::GraphProto& graph, attr_epsilon->set_f(epsilon); } -std::vector SerializeMatMulNBitsModel(const DecodeBenchConfig& config) { - const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; - const int64_t blob_size = (config.block_size * config.bits) / 8; - - ONNX_NAMESPACE::ModelProto model; - model.set_ir_version(10); - - auto* onnx_opset = model.add_opset_import(); - onnx_opset->set_domain(""); - onnx_opset->set_version(21); - auto* ms_opset = model.add_opset_import(); - ms_opset->set_domain("com.microsoft"); - ms_opset->set_version(1); - - auto* graph = model.mutable_graph(); - graph->set_name("WebGpuMatMulNBitsDecode"); - - auto* input = graph->add_input(); - input->set_name("A"); - input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); - - auto* output = graph->add_output(); - output->set_name("Y"); - output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.n); - - std::vector packed_b(static_cast(config.n * k_blocks * blob_size), uint8_t{0x11}); - std::vector scales(static_cast(config.n * k_blocks), Ort::Float16_t(0.03125f)); - - AddTensorInitializer(*graph, "B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, - {config.n, k_blocks, blob_size}, packed_b); - AddTensorInitializer(*graph, "scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - {config.n, k_blocks}, scales); - - AddMatMulNBitsNode(*graph, - "MatMulNBitsDecode", - "A", - "B", - "scales", - "Y", - config.k, - config.n, - config.bits, - config.block_size, - config.accuracy_level); - - const auto serialized = model.SerializeAsString(); - return std::vector(serialized.begin(), serialized.end()); -} - std::string GetMlpVariantLabel(MlpDecodeBenchmarkVariant variant) { switch (variant) { case MlpDecodeBenchmarkVariant::kUnfused: return "unfused"; case MlpDecodeBenchmarkVariant::kFused: return "fused"; - case MlpDecodeBenchmarkVariant::kSkipNormThenFused: - return "skip_norm_then_fused"; - case MlpDecodeBenchmarkVariant::kSkipNormPassthroughThenFused: - return "skip_norm_passthrough_then_fused"; } return "unknown"; } std::string GetMlpNormKindLabel(MlpNormKind norm_kind) { - return norm_kind == MlpNormKind::kSkipSimplified ? "skip_simplified" : "plain"; + switch (norm_kind) { + case MlpNormKind::kSimplified: + return "simplified"; + case MlpNormKind::kSkipSimplified: + return "skip_simplified"; + case MlpNormKind::kSkipSimplifiedPassthrough: + return "skip_simplified_passthrough"; + } + + return "unknown"; } std::string GetMlpDecodeBenchmarkLabel(MlpDecodeBenchmarkVariant variant, MlpNormKind norm_kind) { std::ostringstream stream; stream << "fp16_mlp_decode_" << GetMlpNormKindLabel(norm_kind) << '_' << GetMlpVariantLabel(variant) << '_' << (IsDecodeBenchmarkPerfMode() ? "perf" : "correctness") << '_' - << (GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kRtx5060Ti ? "rtx" : "t") << '_' + << "auto_gpu_" << (IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off") << '_' << (IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"); return stream.str(); @@ -956,14 +813,23 @@ std::string GetQkvVariantLabel(QkvDecodeBenchmarkVariant variant) { } std::string GetQkvNormKindLabel(QkvNormKind norm_kind) { - return norm_kind == QkvNormKind::kSkipSimplified ? "skip_simplified" : "simplified"; + switch (norm_kind) { + case QkvNormKind::kSimplified: + return "simplified"; + case QkvNormKind::kSkipSimplified: + return "skip_simplified"; + case QkvNormKind::kSkipSimplifiedPassthrough: + return "skip_simplified_passthrough"; + } + + return "unknown"; } std::string GetQkvDecodeBenchmarkLabel(QkvDecodeBenchmarkVariant variant, QkvNormKind norm_kind) { std::ostringstream stream; stream << "fp16_qkv_norm_" << GetQkvNormKindLabel(norm_kind) << '_' << GetQkvVariantLabel(variant) << '_' << (IsDecodeBenchmarkPerfMode() ? "perf" : "correctness") << '_' - << (GetDecodeBenchmarkGpu() == DecodeBenchmarkGpu::kRtx5060Ti ? "rtx" : "t") << '_' + << "auto_gpu_" << (IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off") << '_' << (IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"); return stream.str(); @@ -990,25 +856,23 @@ std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& co case MlpDecodeBenchmarkVariant::kFused: graph->set_name("WebGpuMatMulNBitsMlpDecodeFused"); break; - case MlpDecodeBenchmarkVariant::kSkipNormThenFused: - graph->set_name("WebGpuMatMulNBitsMlpSkipNormThenFused"); - break; - case MlpDecodeBenchmarkVariant::kSkipNormPassthroughThenFused: - graph->set_name("WebGpuMatMulNBitsMlpSkipNormPassthroughThenFused"); - break; case MlpDecodeBenchmarkVariant::kUnfused: default: graph->set_name("WebGpuMatMulNBitsMlpDecodeUnfused"); break; } + const bool has_skip = norm_kind == MlpNormKind::kSkipSimplified || + norm_kind == MlpNormKind::kSkipSimplifiedPassthrough; + const bool has_skip_passthrough = norm_kind == MlpNormKind::kSkipSimplifiedPassthrough; + auto* input = graph->add_input(); input->set_name("A"); input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); - if (norm_kind == MlpNormKind::kSkipSimplified) { + if (has_skip) { auto* skip_input = graph->add_input(); skip_input->set_name("Skip"); skip_input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); @@ -1021,13 +885,19 @@ std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& co output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.n); + if (has_skip_passthrough) { + auto* skip_sum_output = graph->add_output(); + skip_sum_output->set_name("SkipSum"); + skip_sum_output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + skip_sum_output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + skip_sum_output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); + } std::vector gate_b(static_cast(config.n * k_blocks * blob_size), uint8_t{0x11}); std::vector up_b(static_cast(config.n * k_blocks * blob_size), uint8_t{0x77}); std::vector gate_scales(static_cast(config.n * k_blocks), Ort::Float16_t(0.03125f)); std::vector up_scales(static_cast(config.n * k_blocks), Ort::Float16_t(0.0625f)); std::vector norm_scale(static_cast(config.k), Ort::Float16_t(1.0f)); - AddTensorInitializer(*graph, "gate_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, {config.n, k_blocks, blob_size}, gate_b); AddTensorInitializer(*graph, "up_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, @@ -1036,104 +906,52 @@ std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& co {config.n, k_blocks}, gate_scales); AddTensorInitializer(*graph, "up_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.n, k_blocks}, up_scales); - if (norm_kind == MlpNormKind::kSkipSimplified) { - AddTensorInitializer(*graph, "norm_scale", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - {config.k}, norm_scale); - } + AddTensorInitializer(*graph, "norm_scale", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + {config.k}, norm_scale); if (variant == MlpDecodeBenchmarkVariant::kFused) { - AddMatMulNBitsSiluMulNode(*graph, - "MatMulNBitsSiluMulDecode", - "A", - norm_kind == MlpNormKind::kSkipSimplified ? "Skip" : "", - norm_kind == MlpNormKind::kSkipSimplified ? "norm_scale" : "", - "gate_B", - "gate_scales", - "up_B", - "up_scales", - "Y", - "", - config.k, - config.n, - config.bits, - config.block_size, - config.accuracy_level); - } else if (variant == MlpDecodeBenchmarkVariant::kSkipNormThenFused) { - ORT_ENFORCE(norm_kind == MlpNormKind::kSkipSimplified, - "SkipNormThenFused benchmark variant requires SkipSimplified norm kind."); + AddMatMulNBitsMlpNode(*graph, + "MatMulNBitsMlpDecode", + "A", + has_skip ? "Skip" : "", + "norm_scale", + "gate_B", + "gate_scales", + "up_B", + "up_scales", + "Y", + has_skip_passthrough ? "SkipSum" : "", + config.k, + config.n, + config.bits, + config.block_size, + config.accuracy_level); + } else { + const char* mlp_input_name = "A_norm"; AddTensorValueInfo(*graph, "A_norm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); auto* norm = graph->add_node(); - norm->set_name("InputSkipSimplifiedLayerNorm"); - norm->set_op_type("SkipSimplifiedLayerNormalization"); - norm->set_domain("com.microsoft"); - norm->add_input("A"); - norm->add_input("Skip"); - norm->add_input("norm_scale"); - norm->add_output("A_norm"); - auto* attr_epsilon = norm->add_attribute(); - attr_epsilon->set_name("epsilon"); - attr_epsilon->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); - attr_epsilon->set_f(1e-6f); - - AddMatMulNBitsSiluMulNode(*graph, - "MatMulNBitsSiluMulDecodeAfterSkipNorm", - "A_norm", - "", - "", - "gate_B", - "gate_scales", - "up_B", - "up_scales", - "Y", - "", - config.k, - config.n, - config.bits, - config.block_size, - config.accuracy_level); - } else if (variant == MlpDecodeBenchmarkVariant::kSkipNormPassthroughThenFused) { - ORT_ENFORCE(norm_kind == MlpNormKind::kSkipSimplified, - "SkipNormPassthroughThenFused benchmark variant requires SkipSimplified norm kind."); - - auto* skip_sum_output = graph->add_output(); - skip_sum_output->set_name("SkipOut"); - skip_sum_output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - skip_sum_output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - skip_sum_output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); - - AddMatMulNBitsSiluMulNode(*graph, - "MatMulNBitsSiluMulDecodeWithSkipNormPassthrough", - "A", - "Skip", - "norm_scale", - "gate_B", - "gate_scales", - "up_B", - "up_scales", - "Y", - "SkipOut", - config.k, - config.n, - config.bits, - config.block_size, - config.accuracy_level); - } else { - const char* mlp_input_name = norm_kind == MlpNormKind::kSkipSimplified ? "A_norm" : "A"; - if (norm_kind == MlpNormKind::kSkipSimplified) { - AddTensorValueInfo(*graph, "A_norm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); - auto* norm = graph->add_node(); - norm->set_name("InputSkipSimplifiedLayerNorm"); - norm->set_op_type("SkipSimplifiedLayerNormalization"); + norm->set_name(has_skip ? "InputSkipSimplifiedLayerNorm" : "InputSimplifiedLayerNorm"); + norm->set_op_type(has_skip ? "SkipSimplifiedLayerNormalization" : "SimplifiedLayerNormalization"); + if (has_skip) { norm->set_domain("com.microsoft"); norm->add_input("A"); norm->add_input("Skip"); norm->add_input("norm_scale"); norm->add_output("A_norm"); - auto* attr_epsilon = norm->add_attribute(); - attr_epsilon->set_name("epsilon"); - attr_epsilon->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); - attr_epsilon->set_f(1e-6f); + if (has_skip_passthrough) { + norm->add_output(""); + norm->add_output(""); + norm->add_output("SkipSum"); + } + } else { + norm->add_input("A"); + norm->add_input("norm_scale"); + norm->add_output("A_norm"); } + auto* attr_epsilon = norm->add_attribute(); + attr_epsilon->set_name("epsilon"); + attr_epsilon->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); + attr_epsilon->set_f(1e-6f); AddTensorValueInfo(*graph, "gate_mm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); AddTensorValueInfo(*graph, "up_mm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); @@ -1209,13 +1027,17 @@ std::vector SerializeMatMulNBitsQkvModel(const QkvDecodeBenchConfig& co ? (norm_kind == QkvNormKind::kSkipSimplified ? "WebGpuMatMulNBitsQkvSkipNormFused" : "WebGpuMatMulNBitsQkvSimplifiedNormFused") : (norm_kind == QkvNormKind::kSkipSimplified ? "WebGpuMatMulNBitsQkvSkipNormUnfused" : "WebGpuMatMulNBitsQkvSimplifiedNormUnfused")); + const bool has_skip = norm_kind == QkvNormKind::kSkipSimplified || + norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; + const bool has_skip_passthrough = norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; + auto* input = graph->add_input(); input->set_name("A"); input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); - if (norm_kind == QkvNormKind::kSkipSimplified) { + if (has_skip) { auto* skip_input = graph->add_input(); skip_input->set_name("Skip"); skip_input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); @@ -1233,7 +1055,7 @@ std::vector SerializeMatMulNBitsQkvModel(const QkvDecodeBenchConfig& co add_output("Q", config.q_n); add_output("K", config.kv_n); add_output("V", config.kv_n); - if (norm_kind == QkvNormKind::kSkipSimplified) { + if (has_skip_passthrough) { add_output("SkipSum", config.k); } @@ -1254,43 +1076,47 @@ std::vector SerializeMatMulNBitsQkvModel(const QkvDecodeBenchConfig& co AddTensorInitializer(*graph, "v_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.kv_n, k_blocks}, v_scales); if (variant == QkvDecodeBenchmarkVariant::kFused) { - AddMatMulNBitsQKVSimplifiedLayerNormNode(*graph, - "MatMulNBitsQKVSimplifiedLayerNormDecode", - "A", - norm_kind == QkvNormKind::kSkipSimplified ? "Skip" : "", - "norm_scale", - "q_B", - "q_scales", - "k_B", - "k_scales", - "v_B", - "v_scales", - "Q", - "K", - "V", - norm_kind == QkvNormKind::kSkipSimplified ? "SkipSum" : "", - config.k, - config.q_n, - config.kv_n, - config.bits, - config.block_size, - config.accuracy_level, - 1e-6f); + AddMatMulNBitsQkvNode(*graph, + "MatMulNBitsQkvDecode", + "A", + has_skip ? "Skip" : "", + "norm_scale", + "q_B", + "q_scales", + "k_B", + "k_scales", + "v_B", + "v_scales", + "Q", + "K", + "V", + has_skip_passthrough ? "SkipSum" : "", + config.k, + config.q_n, + config.kv_n, + config.bits, + config.block_size, + config.accuracy_level, + 1e-6f); } else { AddTensorValueInfo(*graph, "A_norm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); auto* norm = graph->add_node(); - norm->set_name(norm_kind == QkvNormKind::kSkipSimplified ? "InputSkipSimplifiedLayerNorm" : "InputSimplifiedLayerNorm"); - norm->set_op_type(norm_kind == QkvNormKind::kSkipSimplified ? "SkipSimplifiedLayerNormalization" : "SimplifiedLayerNormalization"); - if (norm_kind == QkvNormKind::kSkipSimplified) { - AddTensorValueInfo(*graph, "SkipSum", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); + norm->set_name(has_skip ? "InputSkipSimplifiedLayerNorm" : "InputSimplifiedLayerNorm"); + norm->set_op_type(has_skip ? "SkipSimplifiedLayerNormalization" : "SimplifiedLayerNormalization"); + if (has_skip) { + if (has_skip_passthrough) { + AddTensorValueInfo(*graph, "SkipSum", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); + } norm->set_domain("com.microsoft"); norm->add_input("A"); norm->add_input("Skip"); norm->add_input("norm_scale"); norm->add_output("A_norm"); - norm->add_output(""); - norm->add_output(""); - norm->add_output("SkipSum"); + if (has_skip_passthrough) { + norm->add_output(""); + norm->add_output(""); + norm->add_output("SkipSum"); + } } else { norm->add_input("A"); norm->add_input("norm_scale"); @@ -1316,6 +1142,16 @@ Ort::Session CreateSessionFromModelData(const std::vector& model_data, Ort::SessionOptions session_options; session_options.DisableMemPattern(); session_options.SetGraphOptimizationLevel(graph_optimization_level); + if (IsVerboseSessionLogEnabled()) { + session_options.SetLogSeverityLevel(0); + } + + const std::string optimized_model_path = GetOptimizedModelPath(); + if (!optimized_model_path.empty()) { + const auto optimized_model_path_ort = onnxruntime::ToWideString(optimized_model_path); + session_options.SetOptimizedModelFilePath(optimized_model_path_ort.c_str()); + } + if (provider_options != nullptr) { if (IsGraphCaptureBenchmarkEnabled()) { session_options.AddConfigEntry(onnxruntime::webgpu::options::kEnableGraphCapture, @@ -1391,7 +1227,8 @@ void ValidateMlpDecodeOutputs(const std::vector& unfused_model_data, const char* const* input_names, const Ort::Value* input_tensors, size_t input_count, - const char* const* output_names) { + const char* const* output_names, + size_t output_count) { Ort::Session unfused_session = CreateSessionFromModelData(unfused_model_data, &provider_options, GraphOptimizationLevel::ORT_DISABLE_ALL); @@ -1399,47 +1236,47 @@ void ValidateMlpDecodeOutputs(const std::vector& unfused_model_data, &provider_options, GraphOptimizationLevel::ORT_ENABLE_ALL); - auto unfused_outputs = unfused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, 1); - auto fused_outputs = fused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, 1); - - if (unfused_outputs.size() != 1 || fused_outputs.size() != 1) { - throw std::runtime_error("Expected a single output from both unfused and fused MLP sessions."); - } - - const auto& unfused_output = unfused_outputs[0]; - const auto& fused_output = fused_outputs[0]; - const size_t element_count = unfused_output.GetTensorTypeAndShapeInfo().GetElementCount(); - if (element_count != fused_output.GetTensorTypeAndShapeInfo().GetElementCount()) { - throw std::runtime_error("Unfused and fused MLP output sizes do not match."); - } + auto unfused_outputs = unfused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, output_count); + auto fused_outputs = fused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, output_count); - const auto* unfused_data = unfused_output.GetTensorData(); - const auto* fused_data = fused_output.GetTensorData(); - float max_abs_diff = 0.0f; - size_t max_abs_diff_index = 0; - for (size_t i = 0; i < element_count; ++i) { - const float unfused_value = unfused_data[i].ToFloat(); - const float fused_value = fused_data[i].ToFloat(); - const float abs_diff = std::abs(unfused_value - fused_value); - const float allowed_diff = kDecodeCorrectnessAbsTolerance + - kDecodeCorrectnessRelTolerance * std::max(std::abs(unfused_value), std::abs(fused_value)); - if (abs_diff > max_abs_diff) { - max_abs_diff = abs_diff; - max_abs_diff_index = i; + for (size_t output_index = 0; output_index < output_count; ++output_index) { + const auto& unfused_output = unfused_outputs[output_index]; + const auto& fused_output = fused_outputs[output_index]; + const size_t element_count = unfused_output.GetTensorTypeAndShapeInfo().GetElementCount(); + if (element_count != fused_output.GetTensorTypeAndShapeInfo().GetElementCount()) { + throw std::runtime_error("Unfused and fused MLP output sizes do not match."); } - if (abs_diff > allowed_diff) { - std::ostringstream stream; - stream << "MLP decode correctness check failed at index " << i - << ": unfused=" << unfused_value - << ", fused=" << fused_value - << ", abs_diff=" << abs_diff - << ", allowed_diff=" << allowed_diff; - throw std::runtime_error(stream.str()); + + const auto* unfused_data = unfused_output.GetTensorData(); + const auto* fused_data = fused_output.GetTensorData(); + float max_abs_diff = 0.0f; + size_t max_abs_diff_index = 0; + for (size_t i = 0; i < element_count; ++i) { + const float unfused_value = unfused_data[i].ToFloat(); + const float fused_value = fused_data[i].ToFloat(); + const float abs_diff = std::abs(unfused_value - fused_value); + const float allowed_diff = kDecodeCorrectnessAbsTolerance + + kDecodeCorrectnessRelTolerance * std::max(std::abs(unfused_value), std::abs(fused_value)); + if (abs_diff > max_abs_diff) { + max_abs_diff = abs_diff; + max_abs_diff_index = i; + } + if (abs_diff > allowed_diff) { + std::ostringstream stream; + stream << "MLP decode correctness check failed on output " << output_index + << " at index " << i + << ": unfused=" << unfused_value + << ", fused=" << fused_value + << ", abs_diff=" << abs_diff + << ", allowed_diff=" << allowed_diff; + throw std::runtime_error(stream.str()); + } } - } - std::cout << "MLP decode correctness check passed. max_abs_diff=" << max_abs_diff - << " at index " << max_abs_diff_index << std::endl; + std::cout << "MLP decode correctness check passed for output " << output_index + << ". max_abs_diff=" << max_abs_diff + << " at index " << max_abs_diff_index << std::endl; + } } void ValidateQkvDecodeOutputs(const std::vector& unfused_model_data, @@ -1482,91 +1319,9 @@ void ValidateQkvDecodeOutputs(const std::vector& unfused_model_data, std::cout << "QKV decode correctness check passed." << std::endl; } -[[maybe_unused]] static void BM_WebGpuMatMulNBitsDecode(benchmark::State& state) { - try { - const DecodeBenchConfig config{ - state.range(0), - state.range(1), - state.range(2), - state.range(3), - state.range(4), - }; - - if (config.k % config.block_size != 0) { - state.SkipWithError("K must be divisible by block_size for this benchmark skeleton."); - return; - } - - const DecodeTrafficStats traffic = CalculateDecodeTrafficStats(config); - std::vector model_data = SerializeMatMulNBitsModel(config); - const SelectedWebGpuContext& selected_context = GetSelectedWebGpuContext(); - Ort::Session session = CreateSessionFromModelData(model_data, &selected_context.provider_options); - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - std::vector input_shape{1, config.k}; - std::vector activation(static_cast(config.k)); - - std::mt19937 rng(123); - std::uniform_real_distribution dist(-1.0f, 1.0f); - for (auto& value : activation) { - value = Ort::Float16_t(dist(rng)); - } - - const char* input_names[] = {"A"}; - const char* output_names[] = {"Y"}; - - auto input_tensor = Ort::Value::CreateTensor(memory_info, - activation.data(), - activation.size(), - input_shape.data(), - input_shape.size()); - Ort::RunOptions run_options = CreateBenchmarkRunOptions(); - - if (!IsDecodeBenchmarkPerfMode()) { - ValidateDecodeOutputs(model_data, session, input_names, &input_tensor, output_names); - } - - // Warm up shader compilation, allocations, and caches before measured iterations. - for (int i = 0; i < kDecodeWarmupRuns; ++i) { - auto warmup_outputs = session.Run(run_options, input_names, &input_tensor, 1, output_names, 1); - benchmark::DoNotOptimize(warmup_outputs); - } - - double total_kernel_seconds = 0.0; - for (auto _ : state) { - const auto kernel_start = std::chrono::steady_clock::now(); - auto outputs = session.Run(run_options, input_names, &input_tensor, 1, output_names, 1); - const auto kernel_end = std::chrono::steady_clock::now(); - total_kernel_seconds += std::chrono::duration(kernel_end - kernel_start).count(); - benchmark::DoNotOptimize(outputs); - } - - const double total_flops = 2.0 * static_cast(config.n) * static_cast(config.k); - const double achieved_bandwidth_bytes_per_second = - total_kernel_seconds > 0.0 - ? traffic.total_bytes * static_cast(state.iterations()) / total_kernel_seconds - : 0.0; - const double achieved_bandwidth_gbps = achieved_bandwidth_bytes_per_second / 1.0e9; - const double rtx_5060_ti_utilization_pct = - achieved_bandwidth_bytes_per_second / kRtx5060TiTheoreticalBandwidthBytesPerSecond * 100.0; - - state.SetLabel(GetDecodeBenchmarkLabel()); - state.counters["TFLOPS"] = benchmark::Counter( - total_flops, - benchmark::Counter::kIsIterationInvariantRate); - state.counters["MemBW_GBps"] = benchmark::Counter(achieved_bandwidth_gbps); - state.counters["BWUtil_5060Ti_pct"] = benchmark::Counter(rtx_5060_ti_utilization_pct); - state.counters["Traffic_MB"] = benchmark::Counter(traffic.total_bytes / 1.0e6); - state.counters["Input_MB"] = benchmark::Counter(traffic.input_bytes / 1.0e6); - state.counters["PackedW_MB"] = benchmark::Counter(traffic.packed_weight_bytes / 1.0e6); - state.counters["Scales_MB"] = benchmark::Counter(traffic.scale_bytes / 1.0e6); - state.counters["Output_MB"] = benchmark::Counter(traffic.output_bytes / 1.0e6); - state.counters["GraphReplay"] = benchmark::Counter(IsGraphCaptureBenchmarkEnabled() ? 1.0 : 0.0); - } catch (const std::exception& ex) { - state.SkipWithError(ex.what()); - } -} - -void BenchmarkWebGpuMatMulNBitsQkvDecode(benchmark::State& state, QkvDecodeBenchmarkVariant variant, QkvNormKind norm_kind) { +void BenchmarkWebGpuMatMulNBitsQkvDecode(benchmark::State& state, + QkvDecodeBenchmarkVariant variant, + QkvNormKind norm_kind) { try { const QkvDecodeBenchConfig config{ state.range(0), @@ -1602,14 +1357,19 @@ void BenchmarkWebGpuMatMulNBitsQkvDecode(benchmark::State& state, QkvDecodeBench value = Ort::Float16_t(dist(rng)); } + const bool has_skip = norm_kind == QkvNormKind::kSkipSimplified || + norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; + const bool has_skip_passthrough = norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; const char* simplified_input_names[] = {"A"}; const char* skip_input_names[] = {"A", "Skip"}; const char* simplified_output_names[] = {"Q", "K", "V"}; - const char* skip_output_names[] = {"Q", "K", "V", "SkipSum"}; - const char* const* input_names = norm_kind == QkvNormKind::kSkipSimplified ? skip_input_names : simplified_input_names; - const char* const* output_names = norm_kind == QkvNormKind::kSkipSimplified ? skip_output_names : simplified_output_names; - const size_t input_count = norm_kind == QkvNormKind::kSkipSimplified ? 2u : 1u; - const size_t output_count = norm_kind == QkvNormKind::kSkipSimplified ? 4u : 3u; + const char* skip_output_names[] = {"Q", "K", "V"}; + const char* skip_passthrough_output_names[] = {"Q", "K", "V", "SkipSum"}; + const char* const* input_names = has_skip ? skip_input_names : simplified_input_names; + const char* const* output_names = has_skip_passthrough ? skip_passthrough_output_names + : (has_skip ? skip_output_names : simplified_output_names); + const size_t input_count = has_skip ? 2u : 1u; + const size_t output_count = has_skip_passthrough ? 4u : 3u; std::array input_tensors = { Ort::Value::CreateTensor(memory_info, @@ -1693,16 +1453,15 @@ void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, std::vector model_data = SerializeMatMulNBitsMlpModel(config, variant, norm_kind); const SelectedWebGpuContext& selected_context = GetSelectedWebGpuContext(); const GraphOptimizationLevel optimization_level = - variant == MlpDecodeBenchmarkVariant::kUnfused ? GraphOptimizationLevel::ORT_DISABLE_ALL - : GraphOptimizationLevel::ORT_ENABLE_ALL; + variant == MlpDecodeBenchmarkVariant::kUnfused ? GraphOptimizationLevel::ORT_DISABLE_ALL + : GraphOptimizationLevel::ORT_ENABLE_ALL; Ort::Session session = CreateSessionFromModelData(model_data, - &selected_context.provider_options, - optimization_level); + &selected_context.provider_options, + optimization_level); Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); std::vector input_shape{1, config.k}; std::vector activation(static_cast(config.k)); std::vector skip_activation(static_cast(config.k)); - std::mt19937 rng(123); std::uniform_real_distribution dist(-1.0f, 1.0f); for (auto& value : activation) { @@ -1712,12 +1471,17 @@ void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, value = Ort::Float16_t(dist(rng)); } - const char* plain_input_names[] = {"A"}; + const bool has_skip = norm_kind == MlpNormKind::kSkipSimplified || + norm_kind == MlpNormKind::kSkipSimplifiedPassthrough; + const bool has_skip_passthrough = norm_kind == MlpNormKind::kSkipSimplifiedPassthrough; + const char* simplified_input_names[] = {"A"}; const char* skip_input_names[] = {"A", "Skip"}; - const char* const* input_names = norm_kind == MlpNormKind::kSkipSimplified ? skip_input_names : plain_input_names; - const char* output_names[] = {"Y"}; - const size_t input_count = norm_kind == MlpNormKind::kSkipSimplified ? 2u : 1u; - + const char* main_output_names[] = {"Y"}; + const char* skip_passthrough_output_names[] = {"Y", "SkipSum"}; + const char* const* input_names = has_skip ? skip_input_names : simplified_input_names; + const char* const* output_names = has_skip_passthrough ? skip_passthrough_output_names : main_output_names; + const size_t input_count = has_skip ? 2u : 1u; + const size_t output_count = has_skip_passthrough ? 2u : 1u; std::array input_tensors = { Ort::Value::CreateTensor(memory_info, activation.data(), @@ -1738,18 +1502,19 @@ void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, input_names, input_tensors.data(), input_count, - output_names); + output_names, + output_count); } for (int i = 0; i < kDecodeWarmupRuns; ++i) { - auto warmup_outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, 1); + auto warmup_outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, output_count); benchmark::DoNotOptimize(warmup_outputs); } double total_kernel_seconds = 0.0; for (auto _ : state) { const auto kernel_start = std::chrono::steady_clock::now(); - auto outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, 1); + auto outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, output_count); const auto kernel_end = std::chrono::steady_clock::now(); total_kernel_seconds += std::chrono::duration(kernel_end - kernel_start).count(); benchmark::DoNotOptimize(outputs); @@ -1768,8 +1533,6 @@ void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, state.counters["ApproxMemBW_GBps"] = benchmark::Counter(achieved_bandwidth_bytes_per_second / 1.0e9); state.counters["ApproxTraffic_MB"] = benchmark::Counter(traffic.total_bytes / 1.0e6); state.counters["Input_MB"] = benchmark::Counter(traffic.input_bytes / 1.0e6); - state.counters["SkipInput_MB"] = benchmark::Counter(traffic.skip_input_bytes / 1.0e6); - state.counters["NormScale_MB"] = benchmark::Counter(traffic.norm_scale_bytes / 1.0e6); state.counters["PackedW_MB"] = benchmark::Counter(traffic.packed_weight_bytes / 1.0e6); state.counters["Scales_MB"] = benchmark::Counter(traffic.scale_bytes / 1.0e6); state.counters["Intermediate_MB"] = benchmark::Counter(traffic.intermediate_bytes / 1.0e6); @@ -1780,50 +1543,52 @@ void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, } } -static void BM_WebGpuMatMulNBitsMlpDecodeUnfused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused, MlpNormKind::kNone); +static void BM_WebGpuMatMulNBitsMlpSimplifiedDecodeUnfused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused, MlpNormKind::kSimplified); } -static void BM_WebGpuMatMulNBitsMlpDecodeFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused, MlpNormKind::kNone); +static void BM_WebGpuMatMulNBitsMlpSimplifiedDecodeFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused, MlpNormKind::kSimplified); } -static void BM_WebGpuMatMulNBitsMlpSkipDecodeUnfused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused, MlpNormKind::kSkipSimplified); +static void BM_WebGpuMatMulNBitsQkvSimplifiedNormDecodeUnfused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kUnfused, QkvNormKind::kSimplified); } -static void BM_WebGpuMatMulNBitsMlpSkipDecodeFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused, MlpNormKind::kSkipSimplified); +static void BM_WebGpuMatMulNBitsQkvSimplifiedNormDecodeFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kFused, QkvNormKind::kSimplified); } -static void BM_WebGpuMatMulNBitsMlpSkipDecodeSkipNormThenFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kSkipNormThenFused, MlpNormKind::kSkipSimplified); +static void BM_WebGpuMatMulNBitsQkvSkipDecodeUnfused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kUnfused, QkvNormKind::kSkipSimplified); } -static void BM_WebGpuMatMulNBitsMlpSkipDecodeSkipNormPassthroughThenFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kSkipNormPassthroughThenFused, MlpNormKind::kSkipSimplified); +static void BM_WebGpuMatMulNBitsQkvSkipDecodeFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kFused, QkvNormKind::kSkipSimplified); } -static void BM_WebGpuMatMulNBitsQkvDecodeUnfused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kUnfused, QkvNormKind::kSimplified); +static void BM_WebGpuMatMulNBitsQkvSkipPassthroughDecodeUnfused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kUnfused, QkvNormKind::kSkipSimplifiedPassthrough); } -static void BM_WebGpuMatMulNBitsQkvDecodeFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kFused, QkvNormKind::kSimplified); +static void BM_WebGpuMatMulNBitsQkvSkipPassthroughDecodeFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kFused, QkvNormKind::kSkipSimplifiedPassthrough); } -static void BM_WebGpuMatMulNBitsQkvSkipDecodeUnfused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kUnfused, QkvNormKind::kSkipSimplified); +static void BM_WebGpuMatMulNBitsMlpSkipDecodeUnfused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused, MlpNormKind::kSkipSimplified); } -static void BM_WebGpuMatMulNBitsQkvSkipDecodeFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kFused, QkvNormKind::kSkipSimplified); +static void BM_WebGpuMatMulNBitsMlpSkipDecodeFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused, MlpNormKind::kSkipSimplified); } -[[maybe_unused]] void ApplyWebGpuMatMulNBitsDecodeArgs(benchmark::internal::Benchmark* benchmark) { - for (const auto& config : GetDecodeBenchConfigs()) { - benchmark->Args({config.n, config.k, config.bits, config.block_size, config.accuracy_level}); - } +static void BM_WebGpuMatMulNBitsMlpSkipPassthroughDecodeUnfused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused, MlpNormKind::kSkipSimplifiedPassthrough); +} + +static void BM_WebGpuMatMulNBitsMlpSkipPassthroughDecodeFused(benchmark::State& state) { + BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused, MlpNormKind::kSkipSimplifiedPassthrough); } void ApplyWebGpuMatMulNBitsMlpDecodeArgs(benchmark::internal::Benchmark* benchmark) { @@ -1838,71 +1603,79 @@ void ApplyWebGpuMatMulNBitsQkvDecodeArgs(benchmark::internal::Benchmark* benchma } } -// BENCHMARK(BM_WebGpuMatMulNBitsDecode) -// ->Apply(ApplyWebGpuMatMulNBitsDecodeArgs) -// ->Repetitions(5) -// ->ReportAggregatesOnly() -// ->UseRealTime() -// ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsMlpDecodeUnfused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsMlpDecodeFused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); +// Qkv benchmarks +BENCHMARK(BM_WebGpuMatMulNBitsQkvSimplifiedNormDecodeUnfused) + ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); -BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeUnfused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeFused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeSkipNormThenFused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeSkipNormPassthroughThenFused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsQkvDecodeUnfused) - ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsQkvDecodeFused) - ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); +BENCHMARK(BM_WebGpuMatMulNBitsQkvSimplifiedNormDecodeFused) + ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipDecodeUnfused) - ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); + ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipDecodeFused) - ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); + ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipPassthroughDecodeUnfused) + ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipPassthroughDecodeFused) + ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +// Mlp benchmarks +BENCHMARK(BM_WebGpuMatMulNBitsMlpSimplifiedDecodeUnfused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsMlpSimplifiedDecodeFused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeUnfused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeFused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipPassthroughDecodeUnfused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); + +BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipPassthroughDecodeFused) + ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) + ->ReportAggregatesOnly() + ->UseRealTime() + ->Unit(benchmark::TimeUnit::kMicrosecond); } // namespace + diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index 75b4c57c670a5..302768b9fbdc7 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -4,7 +4,6 @@ #include "core/common/inlined_containers.h" #include "core/graph/onnx_protobuf.h" #include "test/unittest_util/framework_test_utils.h" -#include "test/util/include/scoped_env_vars.h" #include "test/capturing_sink.h" #include "test/test_environment.h" #include "gtest/gtest.h" @@ -17,21 +16,6 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { -namespace { - -constexpr const char* kOrtEnableMatMulNBitsSiluFusionEnvVar = "ORT_ENABLE_MATMUL_NBITS_SILU_FUSION"; -constexpr const char* kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar = - "ORT_ENABLE_MATMUL_NBITS_QKV_SIMPLIFIED_LAYER_NORM_FUSION"; - -bool HasTransformerNamed(const InlinedVector>& transformers, - std::string_view name) { - return std::any_of(transformers.begin(), transformers.end(), [&](const auto& transformer) { - return transformer && transformer->Name() == name; - }); -} - -} // namespace - TEST(GraphTransformerUtilsTests, TestGenerateRewriterules) { // Generate all test auto rewrite_rules = optimizer_utils::GenerateRewriteRules(TransformerLevel::Level1); @@ -87,66 +71,6 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { #endif } -TEST(GraphTransformerUtilsTests, MatMulNBitsSiluFusionDisabledByDefault) { -#if defined(DISABLE_CONTRIB_OPS) - GTEST_SKIP() << "MatMulNBitsSiluFusion requires contrib ops."; -#else - const EnvVarMap env_vars{{kOrtEnableMatMulNBitsSiluFusionEnvVar, optional{}}}; - ScopedEnvironmentVariables scoped_env_vars{env_vars}; - - CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); - const auto& logger = DefaultLoggingManager().DefaultLogger(); - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); - - EXPECT_FALSE(HasTransformerNamed(transformers, "MatMulNBitsSiluFusion")); -#endif -} - -TEST(GraphTransformerUtilsTests, MatMulNBitsSiluFusionEnabledViaEnvironmentVariable) { -#if defined(DISABLE_CONTRIB_OPS) - GTEST_SKIP() << "MatMulNBitsSiluFusion requires contrib ops."; -#else - const EnvVarMap env_vars{{kOrtEnableMatMulNBitsSiluFusionEnvVar, optional{"1"}}}; - ScopedEnvironmentVariables scoped_env_vars{env_vars}; - - CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); - const auto& logger = DefaultLoggingManager().DefaultLogger(); - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); - - EXPECT_TRUE(HasTransformerNamed(transformers, "MatMulNBitsSiluFusion")); -#endif -} - -TEST(GraphTransformerUtilsTests, MatMulNBitsQKVSimplifiedLayerNormFusionDisabledByDefault) { -#if defined(DISABLE_CONTRIB_OPS) - GTEST_SKIP() << "MatMulNBitsQKVSimplifiedLayerNormFusion requires contrib ops."; -#else - const EnvVarMap env_vars{{kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar, optional{}}}; - ScopedEnvironmentVariables scoped_env_vars{env_vars}; - - CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); - const auto& logger = DefaultLoggingManager().DefaultLogger(); - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); - - EXPECT_FALSE(HasTransformerNamed(transformers, "MatMulNBitsQKVSimplifiedLayerNormFusion")); -#endif -} - -TEST(GraphTransformerUtilsTests, MatMulNBitsQKVSimplifiedLayerNormFusionEnabledViaEnvironmentVariable) { -#if defined(DISABLE_CONTRIB_OPS) - GTEST_SKIP() << "MatMulNBitsQKVSimplifiedLayerNormFusion requires contrib ops."; -#else - const EnvVarMap env_vars{{kOrtEnableMatMulNBitsQKVSimplifiedLayerNormFusionEnvVar, optional{"1"}}}; - ScopedEnvironmentVariables scoped_env_vars{env_vars}; - - CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); - const auto& logger = DefaultLoggingManager().DefaultLogger(); - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); - - EXPECT_TRUE(HasTransformerNamed(transformers, "MatMulNBitsQKVSimplifiedLayerNormFusion")); -#endif -} - TEST(GraphTransformerUtilsTests, TestDQMatMulNBitsFusionConfigWithContribGating) { SessionOptions session_options; const auto status = session_options.config_options.AddConfigEntry( diff --git a/onnxruntime/test/optimizer/matmul_nbits_silu_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc similarity index 62% rename from onnxruntime/test/optimizer/matmul_nbits_silu_fusion_test.cc rename to onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc index 2a1886caa618e..4449f989e9e55 100644 --- a/onnxruntime/test/optimizer/matmul_nbits_silu_fusion_test.cc +++ b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc @@ -1,10 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/graph/graph_utils.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/graph_transformer_mgr.h" -#include "core/optimizer/matmul_nbits_silu_fusion.h" +#include "core/optimizer/matmul_nbits_mlp_fusion.h" #include "core/optimizer/utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/util/include/asserts.h" #include "test/util/include/default_providers.h" @@ -21,8 +23,9 @@ namespace test { namespace { +constexpr const char* kExpectedActivation = "silu"; + enum class NormAnchorKind { - kNone, kSimplified, kSkipSimplified, }; @@ -46,19 +49,19 @@ NodeAttributes MakeMatMulNBitsAttrs(int64_t k, int64_t n, int64_t block_size, in return attrs; } -Status CheckMatMulNBitsSiluFusedGraphImpl(const Graph& graph, NormAnchorKind norm_anchor_kind) { +Status CheckMatMulNBitsMlpFusedGraphImpl(const Graph& graph, NormAnchorKind norm_anchor_kind) { const auto op_to_count = CountOpsInGraph(graph); - if (OpCount(op_to_count, "com.microsoft.MatMulNBitsSiluMul") != 1 || + if (OpCount(op_to_count, "com.microsoft.MatMulNBitsMlp") != 1 || OpCount(op_to_count, "com.microsoft.MatMulNBits") != 0 || OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || OpCount(op_to_count, "com.microsoft.SkipSimplifiedLayerNormalization") != 0 || OpCount(op_to_count, "Sigmoid") != 0 || OpCount(op_to_count, "Mul") != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected operator counts after MatMulNBitsSiluFusion."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected operator counts after MatMulNBitsMlpFusion."); } for (const auto& node : graph.Nodes()) { - if (node.OpType() == "MatMulNBitsSiluMul") { + if (node.OpType() == "MatMulNBitsMlp") { ORT_RETURN_IF_NOT(node.Domain() == kMSDomain, "Fused node must be in com.microsoft domain."); ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kWebGpuExecutionProvider, "Fused node must be assigned to WebGPU EP."); @@ -66,42 +69,44 @@ Status CheckMatMulNBitsSiluFusedGraphImpl(const Graph& graph, NormAnchorKind nor const bool has_skip = node.InputDefs()[1] != nullptr && !node.InputDefs()[1]->Name().empty(); const bool has_norm_scale = node.InputDefs()[2] != nullptr && !node.InputDefs()[2]->Name().empty(); ORT_RETURN_IF_NOT(has_skip == (norm_anchor_kind == NormAnchorKind::kSkipSimplified), - "Unexpected skip input presence on fused node."); - ORT_RETURN_IF_NOT(has_norm_scale == (norm_anchor_kind != NormAnchorKind::kNone), - "Unexpected norm_scale input presence on fused node."); + "Unexpected skip input presence on fused node."); + ORT_RETURN_IF_NOT(has_norm_scale, + "Expected norm_scale input on fused node."); + ORT_RETURN_IF_NOT(node.OutputDefs().size() == 1u, + "Non-passthrough fusion should expose only the Y output."); + + const auto* activation_attr = graph_utils::GetNodeAttribute(node, "activation"); + ORT_RETURN_IF_NOT(activation_attr != nullptr && activation_attr->s() == kExpectedActivation, + "Fused node must carry activation='silu'."); } } return Status::OK(); } -Status CheckMatMulNBitsSiluFusedGraph(const Graph& graph) { - return CheckMatMulNBitsSiluFusedGraphImpl(graph, NormAnchorKind::kNone); +Status CheckMatMulNBitsMlpSimplifiedFusedGraph(const Graph& graph) { + return CheckMatMulNBitsMlpFusedGraphImpl(graph, NormAnchorKind::kSimplified); } -Status CheckMatMulNBitsSiluSimplifiedFusedGraph(const Graph& graph) { - return CheckMatMulNBitsSiluFusedGraphImpl(graph, NormAnchorKind::kSimplified); +Status CheckMatMulNBitsMlpSkipFusedGraph(const Graph& graph) { + return CheckMatMulNBitsMlpFusedGraphImpl(graph, NormAnchorKind::kSkipSimplified); } -Status CheckMatMulNBitsSiluSkipFusedGraph(const Graph& graph) { - return CheckMatMulNBitsSiluFusedGraphImpl(graph, NormAnchorKind::kSkipSimplified); -} - -Status CheckMatMulNBitsSiluSkipOutputPassthroughFusedGraph(const Graph& graph) { +Status CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(const Graph& graph) { const auto op_to_count = CountOpsInGraph(graph); - if (OpCount(op_to_count, "com.microsoft.MatMulNBitsSiluMul") != 1 || + if (OpCount(op_to_count, "com.microsoft.MatMulNBitsMlp") != 1 || OpCount(op_to_count, "com.microsoft.MatMulNBits") != 0 || OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || OpCount(op_to_count, "com.microsoft.SkipSimplifiedLayerNormalization") != 0 || OpCount(op_to_count, "Sigmoid") != 0 || OpCount(op_to_count, "Mul") != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Unexpected operator counts after MatMulNBitsSiluFusion with skip output passthrough."); + "Unexpected operator counts after MatMulNBitsMlpFusion with skip output passthrough."); } bool found_fused_node = false; for (const auto& node : graph.Nodes()) { - if (node.OpType() != "MatMulNBitsSiluMul") { + if (node.OpType() != "MatMulNBitsMlp") { continue; } @@ -111,22 +116,26 @@ Status CheckMatMulNBitsSiluSkipOutputPassthroughFusedGraph(const Graph& graph) { "Fused node must be assigned to WebGPU EP."); ORT_RETURN_IF_NOT(node.InputDefs().size() == 9u, "Fused node must have 9 inputs."); ORT_RETURN_IF_NOT(node.OutputDefs().size() == 2u, - "Fused node must expose Y and the passthrough residual output."); + "Fused node must expose Y and the passthrough residual output."); const bool has_skip = node.InputDefs()[1] != nullptr && !node.InputDefs()[1]->Name().empty(); const bool has_norm_scale = node.InputDefs()[2] != nullptr && !node.InputDefs()[2]->Name().empty(); ORT_RETURN_IF_NOT(has_skip && has_norm_scale, - "Skip output passthrough should remain fused into MatMulNBitsSiluMul."); + "Skip output passthrough should remain fused into MatMulNBitsMlp."); ORT_RETURN_IF_NOT(node.OutputDefs()[1] != nullptr && !node.OutputDefs()[1]->Name().empty(), - "Expected fused node to preserve the residual passthrough output."); + "Expected fused node to preserve the residual passthrough output."); + + const auto* activation_attr = graph_utils::GetNodeAttribute(node, "activation"); + ORT_RETURN_IF_NOT(activation_attr != nullptr && activation_attr->s() == kExpectedActivation, + "Fused node must carry activation='silu'."); } - ORT_RETURN_IF_NOT(found_fused_node, "Expected a MatMulNBitsSiluMul node in the transformed graph."); + ORT_RETURN_IF_NOT(found_fused_node, "Expected a MatMulNBitsMlp node in the transformed graph."); return Status::OK(); } -void BuildMatMulNBitsSiluWebGpuPatternImpl(ModelTestBuilder& builder, - NormAnchorKind norm_anchor_kind, - SkipOutputKind skip_output_kind = SkipOutputKind::kNone) { +void BuildMatMulNBitsMlpWebGpuPatternImpl(ModelTestBuilder& builder, + NormAnchorKind norm_anchor_kind, + SkipOutputKind skip_output_kind = SkipOutputKind::kNone) { constexpr int64_t k = 16; constexpr int64_t n = 8; constexpr int64_t block_size = 16; @@ -150,9 +159,7 @@ void BuildMatMulNBitsSiluWebGpuPatternImpl(ModelTestBuilder& builder, NodeArg* up_scale = builder.MakeInitializer({n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); NodeArg* up_bias = builder.MakeInitializer({n}, MLFloat16(0.0f), MLFloat16(0.0f)); - NodeArg* normalized_input = norm_anchor_kind == NormAnchorKind::kNone - ? input - : builder.MakeIntermediate(std::vector{1, k}); + NodeArg* normalized_input = builder.MakeIntermediate(std::vector{1, k}); NodeArg* gate_out = builder.MakeIntermediate(std::vector{1, n}); NodeArg* up_out = builder.MakeIntermediate(std::vector{1, n}); NodeArg* sigmoid_out = builder.MakeIntermediate(std::vector{1, n}); @@ -176,8 +183,8 @@ void BuildMatMulNBitsSiluWebGpuPatternImpl(ModelTestBuilder& builder, norm_outputs.push_back(residual_output); } norm = &builder.AddNode("SkipSimplifiedLayerNormalization", {input, skip_input, norm_scale}, norm_outputs, - kMSDomain); - } else if (norm_anchor_kind == NormAnchorKind::kSimplified) { + kMSDomain); + } else { NodeArg* norm_scale = builder.MakeInitializer({k}, MLFloat16(1.0f), MLFloat16(1.0f)); norm = &builder.AddNode("SimplifiedLayerNormalization", {input, norm_scale}, {normalized_input}); } @@ -204,164 +211,129 @@ void BuildMatMulNBitsSiluWebGpuPatternImpl(ModelTestBuilder& builder, SetWebGpuProvider(final_mul); } -void BuildMatMulNBitsSiluWebGpuPattern(ModelTestBuilder& builder) { - BuildMatMulNBitsSiluWebGpuPatternImpl(builder, NormAnchorKind::kNone); +void BuildMatMulNBitsMlpSimplifiedWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSimplified); } -void BuildMatMulNBitsSiluSimplifiedWebGpuPattern(ModelTestBuilder& builder) { - BuildMatMulNBitsSiluWebGpuPatternImpl(builder, NormAnchorKind::kSimplified); +void BuildMatMulNBitsMlpSkipWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified); } -void BuildMatMulNBitsSiluSkipWebGpuPattern(ModelTestBuilder& builder) { - BuildMatMulNBitsSiluWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified); -} - -void BuildMatMulNBitsSiluSkipOutputPassthroughWebGpuPattern(ModelTestBuilder& builder) { - BuildMatMulNBitsSiluWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified, SkipOutputKind::kGraphOutput); +void BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified, SkipOutputKind::kGraphOutput); } } // namespace -TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionFusesWebGpuPattern) { +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSimplifiedWebGpuPattern) { ASSERT_STATUS_OK(TestGraphTransformer( - BuildMatMulNBitsSiluWebGpuPattern, + BuildMatMulNBitsMlpSimplifiedWebGpuPattern, 21, *logger_, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), TransformerLevel::Level2, 1, nullptr, - CheckMatMulNBitsSiluFusedGraph)); + CheckMatMulNBitsMlpSimplifiedFusedGraph)); } -TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionFusesSkipWebGpuPattern) { +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSkipWebGpuPattern) { ASSERT_STATUS_OK(TestGraphTransformer( - BuildMatMulNBitsSiluSkipWebGpuPattern, + BuildMatMulNBitsMlpSkipWebGpuPattern, 21, *logger_, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), TransformerLevel::Level2, 1, nullptr, - CheckMatMulNBitsSiluSkipFusedGraph)); + CheckMatMulNBitsMlpSkipFusedGraph)); } -TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionFusesSkipWebGpuPatternWithResidualOutputPassthrough) { +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSkipWebGpuPatternWithResidualOutputPassthrough) { ASSERT_STATUS_OK(TestGraphTransformer( - BuildMatMulNBitsSiluSkipOutputPassthroughWebGpuPattern, + BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPattern, 21, *logger_, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), TransformerLevel::Level2, 1, nullptr, - CheckMatMulNBitsSiluSkipOutputPassthroughFusedGraph)); + CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph)); } -TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionFusesSimplifiedWebGpuPattern) { - ASSERT_STATUS_OK(TestGraphTransformer( - BuildMatMulNBitsSiluSimplifiedWebGpuPattern, - 21, - *logger_, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - TransformerLevel::Level2, - 1, - nullptr, - CheckMatMulNBitsSiluSimplifiedFusedGraph)); -} - -TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionMatchesUnfusedWebGpuResults) { +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedWebGpuResults) { auto webgpu_ep = DefaultWebGpuExecutionProvider(); if (!webgpu_ep) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } auto check_transformed_graph = [](InferenceSessionWrapper& session) { - ASSERT_STATUS_OK(CheckMatMulNBitsSiluFusedGraph(session.GetGraph())); + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSimplifiedFusedGraph(session.GetGraph())); }; TransformerTester( - BuildMatMulNBitsSiluWebGpuPattern, + BuildMatMulNBitsMlpSimplifiedWebGpuPattern, check_transformed_graph, TransformerLevel::Level1, TransformerLevel::Level2, 21, 1e-3, 1e-3, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), {}, {}, std::move(webgpu_ep)); } -TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionMatchesUnfusedSkipWebGpuResults) { +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResults) { auto webgpu_ep = DefaultWebGpuExecutionProvider(); if (!webgpu_ep) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } auto check_transformed_graph = [](InferenceSessionWrapper& session) { - ASSERT_STATUS_OK(CheckMatMulNBitsSiluSkipFusedGraph(session.GetGraph())); + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipFusedGraph(session.GetGraph())); }; TransformerTester( - BuildMatMulNBitsSiluSkipWebGpuPattern, + BuildMatMulNBitsMlpSkipWebGpuPattern, check_transformed_graph, TransformerLevel::Level1, TransformerLevel::Level2, 21, 1e-3, 1e-3, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), {}, {}, std::move(webgpu_ep)); } -TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionMatchesUnfusedSkipWebGpuResultsWithResidualOutputPassthrough) { +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResultsWithResidualOutputPassthrough) { auto webgpu_ep = DefaultWebGpuExecutionProvider(); if (!webgpu_ep) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } - auto check_transformed_graph = [](InferenceSessionWrapper& session) { - ASSERT_STATUS_OK(CheckMatMulNBitsSiluSkipOutputPassthroughFusedGraph(session.GetGraph())); + auto add_session_options = [](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableSpecifiedOptimizers, + "EliminateIdentity")); }; - TransformerTester( - BuildMatMulNBitsSiluSkipOutputPassthroughWebGpuPattern, - check_transformed_graph, - TransformerLevel::Level1, - TransformerLevel::Level2, - 21, - 1e-3, - 1e-3, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - {}, - {}, - std::move(webgpu_ep)); -} - -TEST_F(GraphTransformationTests, MatMulNBitsSiluFusionMatchesUnfusedSimplifiedWebGpuResults) { - auto webgpu_ep = DefaultWebGpuExecutionProvider(); - if (!webgpu_ep) { - GTEST_SKIP() << "WebGPU EP unavailable in this build."; - } - auto check_transformed_graph = [](InferenceSessionWrapper& session) { - ASSERT_STATUS_OK(CheckMatMulNBitsSiluSimplifiedFusedGraph(session.GetGraph())); + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(session.GetGraph())); }; TransformerTester( - BuildMatMulNBitsSiluSimplifiedWebGpuPattern, + BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPattern, check_transformed_graph, TransformerLevel::Level1, TransformerLevel::Level2, 21, 1e-3, 1e-3, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - {}, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + add_session_options, {}, std::move(webgpu_ep)); } diff --git a/onnxruntime/test/optimizer/matmul_nbits_qkv_sln_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc similarity index 63% rename from onnxruntime/test/optimizer/matmul_nbits_qkv_sln_fusion_test.cc rename to onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc index b6e936ab883ad..853d641c755a1 100644 --- a/onnxruntime/test/optimizer/matmul_nbits_qkv_sln_fusion_test.cc +++ b/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc @@ -3,7 +3,7 @@ #include "core/graph/node_attr_utils.h" #include "core/optimizer/graph_transformer_mgr.h" -#include "core/optimizer/matmul_nbits_qkv_sln_fusion.h" +#include "core/optimizer/matmul_nbits_qkv_fusion.h" #include "core/optimizer/utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -36,18 +36,18 @@ NodeAttributes MakeMatMulNBitsAttrs(int64_t k, int64_t n, int64_t block_size, in return attrs; } -Status CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraphImpl(const Graph& graph, bool expect_skip_sln_output) { +Status CheckMatMulNBitsQkvFusedGraphImpl(const Graph& graph, bool expect_skip_sln_output) { const auto op_to_count = CountOpsInGraph(graph); - if (OpCount(op_to_count, "com.microsoft.MatMulNBitsQKVSimplifiedLayerNorm") != 1 || + if (OpCount(op_to_count, "com.microsoft.MatMulNBitsQkv") != 1 || OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || OpCount(op_to_count, "com.microsoft.SkipSimplifiedLayerNormalization") != 0 || OpCount(op_to_count, "com.microsoft.MatMulNBits") != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Unexpected operator counts after MatMulNBitsQKVSimplifiedLayerNormFusion."); + "Unexpected operator counts after MatMulNBitsQkvFusion."); } for (const auto& node : graph.Nodes()) { - if (node.OpType() == "MatMulNBitsQKVSimplifiedLayerNorm") { + if (node.OpType() == "MatMulNBitsQkv") { ORT_RETURN_IF_NOT(node.Domain() == kMSDomain, "Fused node must be in com.microsoft domain."); ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kWebGpuExecutionProvider, "Fused node must be assigned to WebGPU EP."); @@ -60,15 +60,19 @@ Status CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraphImpl(const Graph& graph, return Status::OK(); } -Status CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraph(Graph& graph) { - return CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraphImpl(static_cast(graph), false); +Status CheckMatMulNBitsQkvFusedGraph(Graph& graph) { + return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), false); } -Status CheckMatMulNBitsQKVSimplifiedLayerNormSkipFusedGraph(Graph& graph) { - return CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraphImpl(static_cast(graph), true); +Status CheckMatMulNBitsQkvSkipFusedGraph(Graph& graph) { + return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), false); } -void BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPatternImpl(ModelTestBuilder& builder, bool with_skip_input) { +Status CheckMatMulNBitsQkvSkipOutputPassthroughFusedGraph(Graph& graph) { + return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), true); +} + +void BuildMatMulNBitsQkvWebGpuPatternImpl(ModelTestBuilder& builder, bool with_skip_input, bool with_skip_output) { constexpr int64_t k = 16; constexpr int64_t q_n = 8; constexpr int64_t kv_n = 4; @@ -106,17 +110,21 @@ void BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPatternImpl(ModelTestBuilder& b NodeArg* norm_out = builder.MakeIntermediate(std::vector{1, k}); NodeArg* optional_norm_output_1 = builder.MakeOptionalTensor(); NodeArg* optional_norm_output_2 = builder.MakeOptionalTensor(); - NodeArg* residual_out = with_skip_input ? builder.MakeIntermediate(std::vector{1, k}) : nullptr; + NodeArg* residual_out = (with_skip_input && with_skip_output) ? builder.MakeIntermediate(std::vector{1, k}) : nullptr; NodeArg* q_output = builder.MakeOutput(std::vector{1, q_n}); NodeArg* k_output = builder.MakeOutput(std::vector{1, kv_n}); NodeArg* v_output = builder.MakeOutput(std::vector{1, kv_n}); - NodeArg* residual_passthrough = with_skip_input ? builder.MakeOutput(std::vector{1, k}) : nullptr; + NodeArg* residual_passthrough = (with_skip_input && with_skip_output) ? builder.MakeOutput(std::vector{1, k}) : nullptr; NodeAttributes q_attrs = MakeMatMulNBitsAttrs(k, q_n, block_size, bits, accuracy_level); NodeAttributes kv_attrs = MakeMatMulNBitsAttrs(k, kv_n, block_size, bits, accuracy_level); Node& norm = with_skip_input - ? builder.AddNode("SkipSimplifiedLayerNormalization", {input, skip_input, norm_scale}, {norm_out, optional_norm_output_1, optional_norm_output_2, residual_out}, kMSDomain) + ? builder.AddNode("SkipSimplifiedLayerNormalization", + {input, skip_input, norm_scale}, + with_skip_output ? std::vector{norm_out, optional_norm_output_1, optional_norm_output_2, residual_out} + : std::vector{norm_out}, + kMSDomain) : builder.AddNode("SimplifiedLayerNormalization", {input, norm_scale}, {norm_out}); norm.AddAttribute("epsilon", 1e-6f); @@ -129,71 +137,111 @@ void BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPatternImpl(ModelTestBuilder& b SetWebGpuProvider(k_matmul); SetWebGpuProvider(v_matmul); - if (with_skip_input) { + if (with_skip_output) { Node& residual_identity = builder.AddNode("Identity", {residual_out}, {residual_passthrough}); SetWebGpuProvider(residual_identity); } } -void BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPattern(ModelTestBuilder& builder) { - BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPatternImpl(builder, false); +void BuildMatMulNBitsQkvWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsQkvWebGpuPatternImpl(builder, false, false); +} + +void BuildMatMulNBitsQkvSkipWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsQkvWebGpuPatternImpl(builder, true, false); } -void BuildMatMulNBitsQKVSimplifiedLayerNormSkipWebGpuPattern(ModelTestBuilder& builder) { - BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPatternImpl(builder, true); +void BuildMatMulNBitsQkvSkipOutputPassthroughWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsQkvWebGpuPatternImpl(builder, true, true); } } // namespace -TEST_F(GraphTransformationTests, MatMulNBitsQKVSimplifiedLayerNormFusionFusesWebGpuPattern) { +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionFusesWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsQkvWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsQkvFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedWebGpuResults) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), false)); + }; + + TransformerTester( + BuildMatMulNBitsQkvWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + {}, + {}, + std::move(webgpu_ep)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionFusesSkipWebGpuPattern) { ASSERT_STATUS_OK(TestGraphTransformer( - BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPattern, + BuildMatMulNBitsQkvSkipWebGpuPattern, 21, *logger_, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), TransformerLevel::Level2, 1, nullptr, - CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraph)); + CheckMatMulNBitsQkvSkipFusedGraph)); } -TEST_F(GraphTransformationTests, MatMulNBitsQKVSimplifiedLayerNormFusionMatchesUnfusedWebGpuResults) { +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuResults) { auto webgpu_ep = DefaultWebGpuExecutionProvider(); if (!webgpu_ep) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } auto check_transformed_graph = [](InferenceSessionWrapper& session) { - ASSERT_STATUS_OK(CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraphImpl(session.GetGraph(), false)); + ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), false)); }; TransformerTester( - BuildMatMulNBitsQKVSimplifiedLayerNormWebGpuPattern, + BuildMatMulNBitsQkvSkipWebGpuPattern, check_transformed_graph, TransformerLevel::Level1, TransformerLevel::Level2, 21, 1e-3, 1e-3, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), {}, {}, std::move(webgpu_ep)); } -TEST_F(GraphTransformationTests, MatMulNBitsQKVSimplifiedLayerNormFusionFusesSkipWebGpuPattern) { +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionFusesSkipWebGpuPatternWithResidualOutputPassthrough) { ASSERT_STATUS_OK(TestGraphTransformer( - BuildMatMulNBitsQKVSimplifiedLayerNormSkipWebGpuPattern, + BuildMatMulNBitsQkvSkipOutputPassthroughWebGpuPattern, 21, *logger_, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), TransformerLevel::Level2, 1, nullptr, - CheckMatMulNBitsQKVSimplifiedLayerNormSkipFusedGraph)); + CheckMatMulNBitsQkvSkipOutputPassthroughFusedGraph)); } -TEST_F(GraphTransformationTests, MatMulNBitsQKVSimplifiedLayerNormFusionMatchesUnfusedSkipWebGpuResults) { +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuResultsWithResidualOutputPassthrough) { auto webgpu_ep = DefaultWebGpuExecutionProvider(); if (!webgpu_ep) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; @@ -205,18 +253,18 @@ TEST_F(GraphTransformationTests, MatMulNBitsQKVSimplifiedLayerNormFusionMatchesU }; auto check_transformed_graph = [](InferenceSessionWrapper& session) { - ASSERT_STATUS_OK(CheckMatMulNBitsQKVSimplifiedLayerNormFusedGraphImpl(session.GetGraph(), true)); + ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), true)); }; TransformerTester( - BuildMatMulNBitsQKVSimplifiedLayerNormSkipWebGpuPattern, + BuildMatMulNBitsQkvSkipOutputPassthroughWebGpuPattern, check_transformed_graph, TransformerLevel::Level1, TransformerLevel::Level2, 21, 1e-3, 1e-3, - std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), add_session_options, {}, std::move(webgpu_ep)); From 30485ddfee73983ac029b99118a40f48bb299d05 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Wed, 29 Apr 2026 18:35:51 -0700 Subject: [PATCH 09/26] Move back to workgroup/tile_size default --- .../contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc | 6 ------ .../contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc | 6 ------ 2 files changed, 12 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc index 4dbb4d4e48f56..9990ba8ca0399 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc @@ -503,12 +503,6 @@ ONNX_OPERATOR_KERNEL_EX( uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; - if (context.AdapterInfo().vendor != std::string_view{"intel"} && N <= 2048) { - workgroup_size = 64; - tile_size = 4; - tile_size_k_vec = 16; - } - const uint32_t elements_in_value_b = components_b * (32u / onnxruntime::narrow(bits_)); const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; const uint32_t k_tile_iterations = K / tile_size_k; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc index 3ca9b1280d011..dafbc2a2c7781 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc @@ -481,12 +481,6 @@ Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont uint32_t tile_size = 8; uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; - if (context.AdapterInfo().vendor != std::string_view{"intel"} && std::max(Nq, Nkv) <= 2048) { - workgroup_size = 64; - tile_size = 4; - tile_size_k_vec = 16; - } - const uint32_t elements_in_value_b = components_b * (32u / onnxruntime::narrow(bits_)); const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; const uint32_t k_tile_iterations = K / tile_size_k; From 13bf9793a6f92cf4bcba8ebdde253f72c59e170f Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Thu, 30 Apr 2026 20:01:16 -0700 Subject: [PATCH 10/26] Copilot comments + Fix builds + Fix lint + Fusion diagrams --- .../webgpu/quantization/matmul_nbits.cc | 10 +- .../webgpu/quantization/matmul_nbits_common.h | 2 +- .../webgpu/quantization/matmul_nbits_mlp.cc | 121 +++++++++--------- .../webgpu/quantization/matmul_nbits_mlp.h | 8 +- .../webgpu/quantization/matmul_nbits_qkv.cc | 12 +- .../core/graph/contrib_ops/contrib_defs.cc | 5 +- .../core/optimizer/matmul_nbits_mlp_fusion.cc | 36 ++++-- .../core/optimizer/matmul_nbits_mlp_fusion.h | 16 +++ .../core/optimizer/matmul_nbits_qkv_fusion.cc | 30 +++-- .../core/optimizer/matmul_nbits_qkv_fusion.h | 16 +++ .../core/providers/webgpu/allocator.cc | 6 +- .../webgpu_matmul_nbits_decode.cc | 1 - .../optimizer/matmul_nbits_mlp_fusion_test.cc | 6 +- 13 files changed, 160 insertions(+), 109 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index be30e51290087..70c4ddfb19c04 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -255,11 +255,11 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, // WideTileProgram // This program is optimized for Block32 prefill using Tile16x128. const bool use_wide_tile_program = WouldApplyWideTileMatMulNBitsInCurrentDispatch(a, - K_op, - N_op, - block_size_op, - nbits, - has_weight_idx_indirect); + K_op, + N_op, + block_size_op, + nbits, + has_weight_idx_indirect); if (use_wide_tile_program) { // Enforce output components to 1. diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h index 58417bfca6918..883c6be02baa8 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h @@ -12,7 +12,7 @@ class Tensor; namespace webgpu { class ComputeContext; } // namespace webgpu -} +} // namespace onnxruntime namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc index 9990ba8ca0399..501209a213bde 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc @@ -23,7 +23,6 @@ namespace { constexpr uint32_t kFusedDecodeFastPathBits = 4u; constexpr uint32_t kFusedDecodeFastPathBlockSize = 32u; -constexpr float kSkipSimplifiedLayerNormEpsilon = 1e-05f; TensorShape GetOverrideShape(const TensorShape& shape, int components) { return TensorShape{shape.Size() / components}; @@ -177,11 +176,11 @@ class MatMulNBitsMlpDecodeProgram final : public Program()), MatMulNBitsMlp); - Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { +Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* skip = context.Input(1); const Tensor* norm_scale = context.Input(2); @@ -479,10 +478,10 @@ ONNX_OPERATOR_KERNEL_EX( y); const bool would_use_wide_tile_unfused = WouldApplyWideTileMatMulNBitsInCurrentDispatch(a, - K_, - N_, - block_size_, - bits_); + K_, + N_, + block_size_, + bits_); const bool can_use_decode_fast_path = is_decode_fast_path_candidate && @@ -501,7 +500,7 @@ ONNX_OPERATOR_KERNEL_EX( uint32_t workgroup_size = 128; uint32_t tile_size = 8; uint32_t tile_size_k_vec = - (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; + (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; const uint32_t elements_in_value_b = components_b * (32u / onnxruntime::narrow(bits_)); const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; @@ -519,14 +518,14 @@ ONNX_OPERATOR_KERNEL_EX( const uint32_t num_N_tile = CeilDiv(N, tile_size); MatMulNBitsMlpDecodeProgram program{tile_size, - has_gate_bias, - has_up_bias, - has_norm_input, - has_skip_input, - has_skip_output, - single_scale_weights, - tile_size_k_vec, - k_unroll_tiles}; + has_gate_bias, + has_up_bias, + has_norm_input, + has_skip_input, + has_skip_output, + single_scale_weights, + tile_size_k_vec, + k_unroll_tiles}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize(num_N_tile, 1, batch_count); program.AddInput({a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); @@ -551,7 +550,7 @@ ONNX_OPERATOR_KERNEL_EX( {num_N_tile}, {batch_count}, {has_skip_input ? onnxruntime::narrow(skip->Shape().Size()) : 0u}, - {kSkipSimplifiedLayerNormEpsilon}}) + {epsilon_}}) .CacheHint(single_scale_weights, has_gate_bias, has_up_bias, @@ -578,57 +577,57 @@ ONNX_OPERATOR_KERNEL_EX( if (skip != nullptr) { Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); - ORT_RETURN_IF_ERROR(ApplySkipSimplifiedLayerNorm(a, skip, norm_scale, kSkipSimplifiedLayerNormEpsilon, + ORT_RETURN_IF_ERROR(ApplySkipSimplifiedLayerNorm(a, skip, norm_scale, epsilon_, context, &normalized_a, input_skip_bias_sum)); return ApplyUnfusedMlp(&normalized_a, - gate_b, - gate_scales, - gate_bias, - up_b, - up_scales, - up_bias, - K_, - N_, - block_size_, - accuracy_level_, - bits_, - context, - y); + gate_b, + gate_scales, + gate_bias, + up_b, + up_scales, + up_bias, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + context, + y); } if (norm_scale != nullptr) { Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); - ORT_RETURN_IF_ERROR(ApplySimplifiedLayerNorm(a, norm_scale, kSkipSimplifiedLayerNormEpsilon, context, &normalized_a)); + ORT_RETURN_IF_ERROR(ApplySimplifiedLayerNorm(a, norm_scale, epsilon_, context, &normalized_a)); return ApplyUnfusedMlp(&normalized_a, - gate_b, - gate_scales, - gate_bias, - up_b, - up_scales, - up_bias, - K_, - N_, - block_size_, - accuracy_level_, - bits_, - context, - y); + gate_b, + gate_scales, + gate_bias, + up_b, + up_scales, + up_bias, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + context, + y); } return ApplyUnfusedMlp(a, - gate_b, - gate_scales, - gate_bias, - up_b, - up_scales, - up_bias, - K_, - N_, - block_size_, - accuracy_level_, - bits_, - context, - y); + gate_b, + gate_scales, + gate_bias, + up_b, + up_scales, + up_bias, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + context, + y); } } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h index 52333d293dce1..c6ce500980ee9 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h @@ -16,18 +16,19 @@ using onnxruntime::webgpu::ComputeContext; class MatMulNBitsMlp final : public WebGpuKernel { public: - explicit MatMulNBitsMlp(const OpKernelInfo& info) : WebGpuKernel(info) { + explicit MatMulNBitsMlp(const OpKernelInfo& info) : WebGpuKernel(info) { K_ = info.GetAttr("K"); N_ = info.GetAttr("N"); block_size_ = info.GetAttr("block_size"); bits_ = info.GetAttr("bits"); accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4); + epsilon_ = info.GetAttrOrDefault("epsilon", 1e-5f); ORT_ENFORCE(info.GetAttr("activation", &activation_).IsOK(), - "MatMulNBitsMlp requires the 'activation' attribute."); + "MatMulNBitsMlp requires the 'activation' attribute."); ORT_ENFORCE(bits_ == 4 || bits_ == 8 || bits_ == 2, "Only 4b/8b/2b quantization is supported for MatMulNBitsMlp op."); ORT_ENFORCE(activation_ == "silu", - "MatMulNBitsMlp currently only supports activation='silu'."); + "MatMulNBitsMlp currently only supports activation='silu'."); } Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; @@ -38,6 +39,7 @@ class MatMulNBitsMlp final : public WebGpuKernel { int64_t block_size_; int64_t accuracy_level_; int64_t bits_; + float epsilon_; std::string activation_; }; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc index dafbc2a2c7781..b260489227243 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc @@ -212,14 +212,14 @@ class MatMulNBitsQkvDecodeProgram final const auto& v_b = shader.AddInput("v_b"); const auto& v_scales_b = shader.AddInput("v_scales_b"); const auto& q_output = shader.AddOutput("q_output", - ShaderUsage::UseValueTypeAlias | - ShaderUsage::UseElementTypeAlias); + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); const auto& k_output = shader.AddOutput("k_output", - ShaderUsage::UseValueTypeAlias | - ShaderUsage::UseElementTypeAlias); + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); const auto& v_output = shader.AddOutput("v_output", - ShaderUsage::UseValueTypeAlias | - ShaderUsage::UseElementTypeAlias); + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); const auto* input_skip_bias_sum = has_skip_output_ ? &shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr; const auto& skip_var = skip != nullptr ? *skip : a; const auto& input_skip_bias_sum_var = input_skip_bias_sum != nullptr ? *input_skip_bias_sum : q_output; diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 4cb5b19f54018..e8e814b10ac6b 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3666,6 +3666,9 @@ When fused from SkipSimplifiedLayerNormalization, the optional residual-sum outp .Attr("activation", "Activation applied to the gate projection.", AttributeProto::STRING) + .Attr("epsilon", + "Epsilon used by the optional fused (Skip)SimplifiedLayerNormalization. Defaults to 1e-5.", + AttributeProto::FLOAT, 1e-5f) .Input(0, "A", "The shared input tensor.", "T1") .Input(1, "skip", "Optional skip input used by SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) .Input(2, "norm_scale", "Optional RMSNorm scale with shape [K] used by SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) @@ -3786,7 +3789,7 @@ This operator is intended as a decode-oriented QKV fusion primitive. "Constrain input and output types to float tensors.") .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - for (int output_index = 0; output_index < ctx.getNumOutputs(); ++output_index) { + for (size_t output_index = 0; output_index < ctx.getNumOutputs(); ++output_index) { propagateElemTypeFromInputToOutput(ctx, 0, output_index); } diff --git a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc index ee4f1bc63aa3a..50ba18593089b 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc @@ -94,6 +94,11 @@ int64_t GetIntAttr(const Node& node, const char* name, int64_t default_value, bo return attr->i(); } +float GetFloatAttr(const Node& node, const char* name, float default_value) { + const auto* attr = graph_utils::GetNodeAttribute(node, name); + return attr == nullptr ? default_value : attr->f(); +} + bool HasSingleNonGraphConsumer(const Graph& graph, const Node& node) { return !graph.NodeProducesGraphOutput(node) && optimizer_utils::CheckOutputEdges(graph, node, 1); } @@ -186,7 +191,9 @@ bool IsFuseCandidate(const Graph& graph, const int64_t gate_accuracy_level = GetIntAttr(gate_matmul, "accuracy_level", 0); const int64_t up_accuracy_level = GetIntAttr(up_matmul, "accuracy_level", 0); - return gate_k == up_k && gate_n == up_n && gate_bits == up_bits && gate_block_size == up_block_size && + return gate_k == up_k && gate_n == up_n && + gate_bits == up_bits && gate_bits == 4 && + gate_block_size == up_block_size && gate_block_size == 32 && gate_accuracy_level == up_accuracy_level; } @@ -256,14 +263,14 @@ Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l } LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: matched candidate output_mul='" << node.Name() - << "' gate='" << gate_matmul->Name() << "' up='" << up_matmul->Name() - << "' sigmoid='" << sigmoid->Name() << "' activation_mul='" << silu_mul->Name() - << "' attrs={K=" << GetIntAttr(*gate_matmul, "K", -1, true) - << ", N=" << GetIntAttr(*gate_matmul, "N", -1, true) - << ", bits=" << GetIntAttr(*gate_matmul, "bits", 4) - << ", block_size=" << GetIntAttr(*gate_matmul, "block_size", -1, true) - << ", accuracy_level=" << GetIntAttr(*gate_matmul, "accuracy_level", 0) - << "}"; + << "' gate='" << gate_matmul->Name() << "' up='" << up_matmul->Name() + << "' sigmoid='" << sigmoid->Name() << "' activation_mul='" << silu_mul->Name() + << "' attrs={K=" << GetIntAttr(*gate_matmul, "K", -1, true) + << ", N=" << GetIntAttr(*gate_matmul, "N", -1, true) + << ", bits=" << GetIntAttr(*gate_matmul, "bits", 4) + << ", block_size=" << GetIntAttr(*gate_matmul, "block_size", -1, true) + << ", accuracy_level=" << GetIntAttr(*gate_matmul, "accuracy_level", 0) + << "}"; if ((!gate_matmul->GetExecutionProviderType().empty() && gate_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) || (!up_matmul->GetExecutionProviderType().empty() && up_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) || @@ -289,6 +296,7 @@ Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l utils::SetNodeAttribute(utils::MakeAttribute("block_size", GetIntAttr(*gate_matmul, "block_size", -1, true)), attrs); utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", GetIntAttr(*gate_matmul, "accuracy_level", 0)), attrs); utils::SetNodeAttribute(utils::MakeAttribute(kActivationAttrName, std::string{kSupportedActivation}), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("epsilon", GetFloatAttr(*norm, "epsilon", 1e-5f)), attrs); NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); const bool is_skip_sln = norm != nullptr && IsSupportedSkipSimplifiedLayerNormalization(*norm); @@ -318,6 +326,8 @@ Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l const auto norm_output_edges = preserve_skip_output ? graph_utils::GraphEdge::GetNodeOutputEdges(*norm) : std::vector{}; + const std::string output_mul_name = node.Name(); + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*norm)); graph.RemoveNode(norm->Index()); graph_utils::RemoveNodeOutputEdges(graph, const_cast(*gate_matmul)); @@ -332,8 +342,8 @@ Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l graph.RemoveNode(node.Index()); Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsMlp"), - "MatMulNBitsMlp", - "fused MatMulNBits gated MLP projections", + "MatMulNBitsMlp", + "fused MatMulNBits gated MLP projections", fused_inputs, fused_outputs, &attrs, @@ -341,7 +351,7 @@ Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l fused_node.SetExecutionProviderType(kWebGpuExecutionProvider); LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: created fused node '" << fused_node.Name() - << "' from output_mul='" << node.Name() << "'"; + << "' from output_mul='" << output_mul_name << "'"; for (const auto& input_edge : norm_input_edges) { int fused_input_index = input_edge.dst_arg_index; @@ -386,4 +396,4 @@ Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l return Status::OK(); } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h index 2208df0f3d3e4..d201256b93f91 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h +++ b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h @@ -7,6 +7,22 @@ namespace onnxruntime { +// Fuses the SwiGLU MLP block (gate / up / down MatMulNBits projections around a +// SimplifiedLayerNormalization anchor) into a single MatMulNBitsMlp contrib op: +// +// ... -> [Skip]SimplifiedLayerNormalization -+-> MatMulNBits (gate) -+-> Sigmoid -+ +// | | | v +// | | +----------> Mul (silu) -+ +// | +-> MatMulNBits (up) ---------------------------+--> Mul -> MatMulNBits (down) -> out +// +--> (optional) skip residual passthrough --> downstream consumers +// +// becomes +// +// ... -> [Skip]SimplifiedLayerNormalization --> MatMulNBitsMlp(activation="silu") -+-> out +// +-> (optional) residual passthrough +// +// Only activation="silu" (i.e. x * Sigmoid(x)) is matched / emitted, and the fusion is restricted +// to the WebGPU EP because MatMulNBitsMlp is a WebGPU-only contrib op. class MatMulNBitsMlpFusion : public GraphTransformer { public: explicit MatMulNBitsMlpFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc index 1eeb38058f34c..e29cd3e4e0030 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc @@ -149,8 +149,8 @@ bool IsFuseCandidate(const Node& norm, const QkvNodes& qkv) { const int64_t v_accuracy_level = GetIntAttr(*qkv.v, "accuracy_level", 0); return q_k == k_k && q_k == v_k && - q_bits == k_bits && q_bits == v_bits && - q_block_size == k_block_size && q_block_size == v_block_size && + q_bits == k_bits && q_bits == v_bits && q_bits == 4 && + q_block_size == k_block_size && q_block_size == v_block_size && q_block_size == 32 && q_accuracy_level == k_accuracy_level && q_accuracy_level == v_accuracy_level; } @@ -190,11 +190,11 @@ Status MatMulNBitsQkvFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l const bool is_skip_sln = IsSupportedSkipSimplifiedLayerNormalization(node); LOGS(logger, VERBOSE) << "MatMulNBitsQkvFusion: matched norm='" << node.Name() - << "' q='" << qkv_nodes->q->Name() << "' k='" << qkv_nodes->k->Name() - << "' v='" << qkv_nodes->v->Name() << "' attrs={K=" << K - << ", Nq=" << Nq << ", Nkv=" << Nkv << ", bits=" << bits - << ", block_size=" << block_size << ", accuracy_level=" << accuracy_level - << ", epsilon=" << epsilon << ", skip_sln=" << is_skip_sln << "}"; + << "' q='" << qkv_nodes->q->Name() << "' k='" << qkv_nodes->k->Name() + << "' v='" << qkv_nodes->v->Name() << "' attrs={K=" << K + << ", Nq=" << Nq << ", Nkv=" << Nkv << ", bits=" << bits + << ", block_size=" << block_size << ", accuracy_level=" << accuracy_level + << ", epsilon=" << epsilon << ", skip_sln=" << is_skip_sln << "}"; NodeAttributes attrs; utils::SetNodeAttribute(utils::MakeAttribute("K", K), attrs); @@ -228,11 +228,17 @@ Status MatMulNBitsQkvFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l fused_outputs.push_back(const_cast(node.OutputDefs()[3])); } + const bool has_residual_output = is_skip_sln && HasProducedOutput(node, 3); + const std::string norm_name = node.Name(); + const std::string q_name = qkv_nodes->q->Name(); + const std::string k_name = qkv_nodes->k->Name(); + const std::string v_name = qkv_nodes->v->Name(); + const auto norm_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node); const auto q_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->q); const auto k_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->k); const auto v_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->v); - const auto norm_output_edges = is_skip_sln && HasProducedOutput(node, 3) + const auto norm_output_edges = has_residual_output ? graph_utils::GraphEdge::GetNodeOutputEdges(node) : std::vector{}; graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->q)); @@ -245,7 +251,7 @@ Status MatMulNBitsQkvFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l graph.RemoveNode(node.Index()); Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsQkv"), - "MatMulNBitsQkv", + "MatMulNBitsQkv", "fused SimplifiedLayerNormalization with Q/K/V MatMulNBits projections", fused_inputs, fused_outputs, @@ -254,8 +260,8 @@ Status MatMulNBitsQkvFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l fused_node.SetExecutionProviderType(kWebGpuExecutionProvider); LOGS(logger, VERBOSE) << "MatMulNBitsQkvFusion: created fused node '" << fused_node.Name() - << "' from norm='" << node.Name() << "' q='" << qkv_nodes->q->Name() - << "' k='" << qkv_nodes->k->Name() << "' v='" << qkv_nodes->v->Name() << "'"; + << "' from norm='" << norm_name << "' q='" << q_name + << "' k='" << k_name << "' v='" << v_name << "'"; for (const auto& input_edge : norm_input_edges) { int fused_input_index = input_edge.dst_arg_index; @@ -275,7 +281,7 @@ Status MatMulNBitsQkvFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l for (const auto& output_edge : v_output_edges) { graph.AddEdge(fused_node.Index(), output_edge.dst_node, 2, output_edge.dst_arg_index); } - if (is_skip_sln && HasProducedOutput(node, 3)) { + if (has_residual_output) { for (const auto& output_edge : norm_output_edges) { if (output_edge.src_arg_index == 3) { graph.AddEdge(fused_node.Index(), output_edge.dst_node, 3, output_edge.dst_arg_index); diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h index 2e028c28190b4..fcbbb78457f52 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h +++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h @@ -7,6 +7,22 @@ namespace onnxruntime { +// Fuses three sibling MatMulNBits Q/K/V projections that share a SimplifiedLayerNormalization +// (or SkipSimplifiedLayerNormalization) anchor into a single MatMulNBitsQkv contrib op: +// +// ... -> [Skip]SimplifiedLayerNormalization -+-> MatMulNBits (Q proj) -+ +// | +-> MatMulNBits (K proj) -+--> downstream consumers +// | +-> MatMulNBits (V proj) -+ +// +--> (optional) skip residual passthrough --> downstream consumers +// +// becomes +// +// ... -> [Skip]SimplifiedLayerNormalization --> MatMulNBitsQkv -+-> Q out +// +-> K out +// +-> V out +// +-> (optional) residual passthrough +// +// The fusion is restricted to the WebGPU EP because MatMulNBitsQkv is a WebGPU-only contrib op. class MatMulNBitsQkvFusion : public GraphTransformer { public: explicit MatMulNBitsQkvFusion( diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index 7565cc8d52a87..08cc9455fd34b 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace webgpu { GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator) - : GpuBufferAllocator([buffer_manager_ptr = &buffer_manager]() -> const BufferManager& { return *buffer_manager_ptr; }, is_read_only_allocator) { + : GpuBufferAllocator([buffer_manager_ptr = &buffer_manager]() -> const BufferManager& { return *buffer_manager_ptr; }, is_read_only_allocator) { } GpuBufferAllocator::GpuBufferAllocator(std::function buffer_manager_getter, bool is_read_only_allocator) @@ -19,8 +19,8 @@ GpuBufferAllocator::GpuBufferAllocator(std::function buf : OrtAllocatorType::OrtDeviceAllocator, WebGpuDevice, OrtMemTypeDefault)), - buffer_manager_getter_{std::move(buffer_manager_getter)}, - mapped_at_creation_{is_read_only_allocator && buffer_manager_getter_().SupportsUMA()} { + buffer_manager_getter_{std::move(buffer_manager_getter)}, + mapped_at_creation_{is_read_only_allocator && buffer_manager_getter_().SupportsUMA()} { } void* GpuBufferAllocator::Alloc(size_t size) { diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc index a95faf04f37d3..3bb67b932dfe7 100644 --- a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc +++ b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc @@ -1678,4 +1678,3 @@ BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipPassthroughDecodeFused) ->Unit(benchmark::TimeUnit::kMicrosecond); } // namespace - diff --git a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc index 4449f989e9e55..2ab6d15f3f8be 100644 --- a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc +++ b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc @@ -69,11 +69,11 @@ Status CheckMatMulNBitsMlpFusedGraphImpl(const Graph& graph, NormAnchorKind norm const bool has_skip = node.InputDefs()[1] != nullptr && !node.InputDefs()[1]->Name().empty(); const bool has_norm_scale = node.InputDefs()[2] != nullptr && !node.InputDefs()[2]->Name().empty(); ORT_RETURN_IF_NOT(has_skip == (norm_anchor_kind == NormAnchorKind::kSkipSimplified), - "Unexpected skip input presence on fused node."); + "Unexpected skip input presence on fused node."); ORT_RETURN_IF_NOT(has_norm_scale, - "Expected norm_scale input on fused node."); + "Expected norm_scale input on fused node."); ORT_RETURN_IF_NOT(node.OutputDefs().size() == 1u, - "Non-passthrough fusion should expose only the Y output."); + "Non-passthrough fusion should expose only the Y output."); const auto* activation_attr = graph_utils::GetNodeAttribute(node, "activation"); ORT_RETURN_IF_NOT(activation_attr != nullptr && activation_attr->s() == kExpectedActivation, From d1090c86403a583f49fb9e790214abf5c7cfc90d Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Thu, 30 Apr 2026 21:19:08 -0700 Subject: [PATCH 11/26] Fix test --- .../test/optimizer/matmul_nbits_mlp_fusion_test.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc index 2ab6d15f3f8be..cc59679aa4d63 100644 --- a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc +++ b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc @@ -136,9 +136,9 @@ Status CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(const Graph& graph) { void BuildMatMulNBitsMlpWebGpuPatternImpl(ModelTestBuilder& builder, NormAnchorKind norm_anchor_kind, SkipOutputKind skip_output_kind = SkipOutputKind::kNone) { - constexpr int64_t k = 16; + constexpr int64_t k = 32; constexpr int64_t n = 8; - constexpr int64_t block_size = 16; + constexpr int64_t block_size = 32; constexpr int64_t bits = 4; constexpr int64_t accuracy_level = 4; constexpr int64_t blob_size = block_size * bits / 8; @@ -146,6 +146,10 @@ void BuildMatMulNBitsMlpWebGpuPatternImpl(ModelTestBuilder& builder, NodeArg* input = builder.MakeInput( std::vector{1, k}, std::vector{ + MLFloat16(-1.0f), MLFloat16(-0.875f), MLFloat16(-0.75f), MLFloat16(-0.625f), + MLFloat16(-0.5f), MLFloat16(-0.375f), MLFloat16(-0.25f), MLFloat16(-0.125f), + MLFloat16(0.125f), MLFloat16(0.25f), MLFloat16(0.375f), MLFloat16(0.5f), + MLFloat16(0.625f), MLFloat16(0.75f), MLFloat16(0.875f), MLFloat16(1.0f), MLFloat16(-1.0f), MLFloat16(-0.875f), MLFloat16(-0.75f), MLFloat16(-0.625f), MLFloat16(-0.5f), MLFloat16(-0.375f), MLFloat16(-0.25f), MLFloat16(-0.125f), MLFloat16(0.125f), MLFloat16(0.25f), MLFloat16(0.375f), MLFloat16(0.5f), From ffacd4c54c133be7d78708e179f0d630515a0556 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Thu, 30 Apr 2026 22:19:13 -0700 Subject: [PATCH 12/26] Fix builds --- cmake/onnxruntime_unittests.cmake | 10 ++++++++++ .../webgpu/quantization/matmul_nbits_mlp.cc | 15 --------------- .../webgpu_matmul_nbits_decode.cc | 19 ++----------------- 3 files changed, 12 insertions(+), 32 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 3ab96b50b9649..d0bd946f5ce7a 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1396,6 +1396,16 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) "$<$>:/utf-8>") endif() target_link_libraries(onnxruntime_benchmark PRIVATE onnx_test_runner_common benchmark::benchmark ${onnx_test_libs}) + if (onnxruntime_USE_WEBGPU AND + NOT onnxruntime_USE_EP_API_ADAPTERS AND + NOT onnxruntime_BUILD_DAWN_SHARED_LIBRARY AND + NOT onnxruntime_USE_EXTERNAL_DAWN AND + NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # webgpu_matmul_nbits_decode.cc uses Dawn's native API directly to create the + # WebGPU device shared with ORT, so we need the Dawn headers and libs visible + # to this benchmark target (they are linked PRIVATE to onnxruntime_providers_webgpu). + target_link_libraries(onnxruntime_benchmark PRIVATE dawn::dawn_native dawn::dawn_proc) + endif() add_dependencies(onnxruntime_benchmark ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_benchmark PROPERTIES FOLDER "ONNXRuntimeTest") diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc index 501209a213bde..c89850567e7ad 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc @@ -124,21 +124,6 @@ Status ApplySkipSimplifiedLayerNorm(const Tensor* x, return context.RunProgram(program); } -Status ApplyUnfusedSiluMul(const Tensor* a, - const Tensor* gate_b, - const Tensor* gate_scales, - const Tensor* gate_bias, - const Tensor* up_b, - const Tensor* up_scales, - const Tensor* up_bias, - int64_t K, - int64_t N, - int64_t block_size, - int64_t accuracy_level, - int64_t bits, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y); - class MatMulNBitsMlpDecodeProgram final : public Program { public: MatMulNBitsMlpDecodeProgram(uint32_t tile_size, diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc index 3bb67b932dfe7..69a4785c12a5a 100644 --- a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc +++ b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc @@ -32,7 +32,6 @@ extern OrtEnv* env; extern const OrtApi* g_ort; namespace { -constexpr const char* kMatMulNBitsAutoTunerEnvVar = "ORT_WEBGPU_MATMUL_NBITS_ENABLE_AUTO_TUNER"; constexpr const char* kDecodeBenchmarkModeEnvVar = "ORT_WEBGPU_MATMUL_NBITS_BENCHMARK_MODE"; constexpr const char* kDecodeBenchmarkGraphCaptureEnvVar = "ORT_WEBGPU_MATMUL_NBITS_ENABLE_GRAPH_CAPTURE"; constexpr const char* kDecodeBenchmarkOptimizedModelPathEnvVar = "ORT_WEBGPU_MATMUL_NBITS_OPTIMIZED_MODEL_PATH"; @@ -46,7 +45,7 @@ enum class DecodeBenchmarkMode { kCorrectness, }; -bool IsMatMulNBitsAutoTunerEnabled(); +bool IsGraphCaptureBenchmarkEnabled(); bool IsGraphCaptureBenchmarkEnabled(); bool IsVerboseSessionLogEnabled(); std::string GetOptimizedModelPath(); @@ -167,7 +166,6 @@ bool IsDecodeBenchmarkPerfMode() { std::string GetDecodeBenchmarkLabel(const char* shape_label = nullptr) { const char* mode_label = IsDecodeBenchmarkPerfMode() ? "perf" : "correctness"; - const char* tuner_label = IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off"; const char* graph_label = IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"; std::ostringstream stream; @@ -175,21 +173,10 @@ std::string GetDecodeBenchmarkLabel(const char* shape_label = nullptr) { if (shape_label != nullptr && shape_label[0] != '\0') { stream << '_' << shape_label; } - stream << '_' << mode_label << "_auto_gpu_" << tuner_label << '_' << graph_label; + stream << '_' << mode_label << "_auto_gpu_" << graph_label; return stream.str(); } -bool IsMatMulNBitsAutoTunerEnabled() { - std::string auto_tuner_env = onnxruntime::Env::Default().GetEnvironmentVar(kMatMulNBitsAutoTunerEnvVar); - if (auto_tuner_env.empty()) { - return false; - } - - std::transform(auto_tuner_env.begin(), auto_tuner_env.end(), auto_tuner_env.begin(), - [](unsigned char value) { return static_cast(std::tolower(value)); }); - return auto_tuner_env != "0" && auto_tuner_env != "false" && auto_tuner_env != "off"; -} - bool IsGraphCaptureBenchmarkEnabled() { std::string graph_capture_env = onnxruntime::Env::Default().GetEnvironmentVar(kDecodeBenchmarkGraphCaptureEnvVar); if (graph_capture_env.empty()) { @@ -803,7 +790,6 @@ std::string GetMlpDecodeBenchmarkLabel(MlpDecodeBenchmarkVariant variant, MlpNor stream << "fp16_mlp_decode_" << GetMlpNormKindLabel(norm_kind) << '_' << GetMlpVariantLabel(variant) << '_' << (IsDecodeBenchmarkPerfMode() ? "perf" : "correctness") << '_' << "auto_gpu_" - << (IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off") << '_' << (IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"); return stream.str(); } @@ -830,7 +816,6 @@ std::string GetQkvDecodeBenchmarkLabel(QkvDecodeBenchmarkVariant variant, QkvNor stream << "fp16_qkv_norm_" << GetQkvNormKindLabel(norm_kind) << '_' << GetQkvVariantLabel(variant) << '_' << (IsDecodeBenchmarkPerfMode() ? "perf" : "correctness") << '_' << "auto_gpu_" - << (IsMatMulNBitsAutoTunerEnabled() ? "tuner_on" : "tuner_off") << '_' << (IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"); return stream.str(); } From 92874ce7ac79fbfeba40488f1e1999e01d69ff76 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Thu, 30 Apr 2026 23:36:27 -0700 Subject: [PATCH 13/26] Fixes --- .../webgpu/quantization/matmul_nbits_common.cc | 10 +++++----- .../onnx/microbenchmark/webgpu_matmul_nbits_decode.cc | 8 ++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc index ef40e224d8fa7..e615b90577f61 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc @@ -85,12 +85,12 @@ bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, return false; } - const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); const uint32_t M = onnxruntime::narrow(helper.M()); - const uint32_t dispatch_M = override_M > 0 ? override_M : M; - const uint32_t N = onnxruntime::narrow(helper.N()); - const uint32_t K = onnxruntime::narrow(helper.K()); - const uint32_t block_size = onnxruntime::narrow(block_size_op); + [[maybe_unused]] const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); + [[maybe_unused]] const uint32_t dispatch_M = override_M > 0 ? override_M : M; + [[maybe_unused]] const uint32_t N = onnxruntime::narrow(helper.N()); + [[maybe_unused]] const uint32_t K = onnxruntime::narrow(helper.K()); + [[maybe_unused]] const uint32_t block_size = onnxruntime::narrow(block_size_op); #if !defined(__wasm__) int32_t local_subgroup_matrix_config_index = -1; diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc index 69a4785c12a5a..a98c3f04d0eaf 100644 --- a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc +++ b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// This benchmark uses Dawn (WebGPU) headers directly, which are only available +// when the WebGPU EP is built (USE_WEBGPU is defined for that EP). On other +// builds the file compiles to an empty translation unit so the benchmark +// target can still link. +#ifdef USE_WEBGPU + #include #include @@ -1663,3 +1669,5 @@ BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipPassthroughDecodeFused) ->Unit(benchmark::TimeUnit::kMicrosecond); } // namespace + +#endif // USE_WEBGPU From a7899c6a43060caf90dc1328f78f29cf2626bbc2 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 1 May 2026 20:38:05 -0700 Subject: [PATCH 14/26] Slim PR: drop benchmark harness, lazy buffer-mgr fix, consteval fix, shader diagnostics These changes are kept on hari/webgpu_perf_1_full locally. The lazy buffer-mgr fix is being submitted as a separate PR (branch hari/webgpu_graph_capture_buffer_fix) because it is an independent correctness fix for a pre-existing latent bug, exposed but not introduced by these fusions. --- cmake/onnxruntime_unittests.cmake | 13 +- .../core/providers/webgpu/allocator.cc | 14 +- onnxruntime/core/providers/webgpu/allocator.h | 5 +- .../core/providers/webgpu/compute_context.h | 6 - .../core/providers/webgpu/program_manager.cc | 84 +- .../core/providers/webgpu/webgpu_context.cc | 10 - .../core/providers/webgpu/webgpu_context.h | 1 - .../webgpu/webgpu_execution_provider.cc | 25 +- .../webgpu/webgpu_execution_provider.h | 1 - onnxruntime/core/session/ort_version_check.h | 21 +- onnxruntime/test/onnx/microbenchmark/main.cc | 2 +- .../webgpu_matmul_nbits_decode.cc | 1673 ----------------- 12 files changed, 28 insertions(+), 1827 deletions(-) delete mode 100644 onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index d0bd946f5ce7a..bd12b50b7af43 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1373,8 +1373,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ${BENCHMARK_DIR}/activation.cc ${BENCHMARK_DIR}/quantize.cc ${BENCHMARK_DIR}/reduceminmax.cc - ${BENCHMARK_DIR}/layer_normalization.cc - ${BENCHMARK_DIR}/webgpu_matmul_nbits_decode.cc) + ${BENCHMARK_DIR}/layer_normalization.cc) target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) target_compile_definitions(onnxruntime_benchmark PRIVATE ${mlas_private_compile_definitions}) @@ -1396,16 +1395,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) "$<$>:/utf-8>") endif() target_link_libraries(onnxruntime_benchmark PRIVATE onnx_test_runner_common benchmark::benchmark ${onnx_test_libs}) - if (onnxruntime_USE_WEBGPU AND - NOT onnxruntime_USE_EP_API_ADAPTERS AND - NOT onnxruntime_BUILD_DAWN_SHARED_LIBRARY AND - NOT onnxruntime_USE_EXTERNAL_DAWN AND - NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - # webgpu_matmul_nbits_decode.cc uses Dawn's native API directly to create the - # WebGPU device shared with ORT, so we need the Dawn headers and libs visible - # to this benchmark target (they are linked PRIVATE to onnxruntime_providers_webgpu). - target_link_libraries(onnxruntime_benchmark PRIVATE dawn::dawn_native dawn::dawn_proc) - endif() add_dependencies(onnxruntime_benchmark ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_benchmark PROPERTIES FOLDER "ONNXRuntimeTest") diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index 08cc9455fd34b..3e1b87821fe2f 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -9,18 +9,14 @@ namespace onnxruntime { namespace webgpu { GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator) - : GpuBufferAllocator([buffer_manager_ptr = &buffer_manager]() -> const BufferManager& { return *buffer_manager_ptr; }, is_read_only_allocator) { -} - -GpuBufferAllocator::GpuBufferAllocator(std::function buffer_manager_getter, bool is_read_only_allocator) : IAllocator( OrtMemoryInfo(WEBGPU_BUFFER, is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator : OrtAllocatorType::OrtDeviceAllocator, WebGpuDevice, OrtMemTypeDefault)), - buffer_manager_getter_{std::move(buffer_manager_getter)}, - mapped_at_creation_{is_read_only_allocator && buffer_manager_getter_().SupportsUMA()} { + buffer_manager_{buffer_manager}, + mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} { } void* GpuBufferAllocator::Alloc(size_t size) { @@ -30,17 +26,15 @@ void* GpuBufferAllocator::Alloc(size_t size) { stats_.num_allocs++; - const auto& buffer_manager = buffer_manager_getter_(); - wgpu::BufferUsage usage = mapped_at_creation_ ? wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite : wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect; - return buffer_manager.Create(size, usage); + return buffer_manager_.Create(size, usage); } void GpuBufferAllocator::Free(void* p) { if (p != nullptr) { - buffer_manager_getter_().Release(static_cast(p)); + buffer_manager_.Release(static_cast(p)); stats_.num_allocs--; } } diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index fadfc8c86cfc4..74b3d669fcf3b 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -3,8 +3,6 @@ #pragma once -#include - #include "core/framework/allocator.h" #include "core/framework/ortdevice.h" @@ -21,7 +19,6 @@ inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU, class GpuBufferAllocator : public IAllocator { public: GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator); - GpuBufferAllocator(std::function buffer_manager_getter, bool is_read_only_allocator); virtual void* Alloc(size_t size) override; virtual void Free(void* p) override; @@ -29,7 +26,7 @@ class GpuBufferAllocator : public IAllocator { private: AllocatorStats stats_; - std::function buffer_manager_getter_; + const BufferManager& buffer_manager_; bool mapped_at_creation_; }; diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 87d61779213b7..632e04a36c7bf 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -38,7 +38,6 @@ class ComputeContextBase { // This ensures no access to BufferManager from other classes, avoiding // potential misuse. friend class WebGpuContext; - friend class ComputeContextBase; private: static const webgpu::BufferManager& Get(const ComputeContextBase& context); @@ -122,11 +121,6 @@ class ComputeContextBase { return webgpu_context_.Run(*this, program); } - inline Status FlushAndWait() { - webgpu_context_.Flush(BufferManagerAccessor::Get(*this)); - return webgpu_context_.WaitForQueueIdle(); - } - protected: WebGpuContext& webgpu_context_; const WebGpuExecutionProvider& ep_; diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index a4654182d3b68..e4376476a885d 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include -#include #include "core/common/common.h" #include "core/common/logging/logging.h" @@ -14,72 +13,6 @@ namespace onnxruntime { namespace webgpu { -namespace { - -const char* CompilationMessageTypeToString(wgpu::CompilationMessageType type) { - switch (type) { - case wgpu::CompilationMessageType::Error: - return "error"; - case wgpu::CompilationMessageType::Warning: - return "warning"; - case wgpu::CompilationMessageType::Info: - return "info"; - default: - return "unknown"; - } -} - -std::string GetShaderCompilationDiagnostics(WebGpuContext& webgpu_context, const wgpu::ShaderModule& shader_module) { - struct CompilationInfoContext { - std::string diagnostics; - } compilation_info_context; - - auto future = shader_module.GetCompilationInfo( - wgpu::CallbackMode::WaitAnyOnly, - [](wgpu::CompilationInfoRequestStatus status, const wgpu::CompilationInfo* compilation_info, CompilationInfoContext* context) { - if (status != wgpu::CompilationInfoRequestStatus::Success) { - context->diagnostics = std::string{"Shader compilation info unavailable. Request status: "} + - (status == wgpu::CompilationInfoRequestStatus::CallbackCancelled ? "callback cancelled" : "unknown"); - return; - } - - if (compilation_info == nullptr || compilation_info->messageCount == 0 || compilation_info->messages == nullptr) { - return; - } - - std::string diagnostics; - diagnostics.reserve(compilation_info->messageCount * 96); - for (size_t i = 0; i < compilation_info->messageCount; ++i) { - const auto& message = compilation_info->messages[i]; - diagnostics += "\n ["; - diagnostics += CompilationMessageTypeToString(message.type); - diagnostics += "]"; - if (message.lineNum > 0) { - diagnostics += " line "; - diagnostics += std::to_string(message.lineNum); - if (message.linePos > 0) { - diagnostics += ':'; - diagnostics += std::to_string(message.linePos); - } - } - diagnostics += ": "; - diagnostics += std::string_view{message.message}; - } - - context->diagnostics = std::move(diagnostics); - }, - &compilation_info_context); - - const Status wait_status = webgpu_context.Wait(future); - if (!wait_status.IsOK()) { - return std::string{"Shader compilation info wait failed: "} + wait_status.ErrorMessage(); - } - - return compilation_info_context.diagnostics; -} - -} // namespace - ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks) : name{program.Name()}, compute_pipeline{compute_pipeline}, @@ -264,7 +197,7 @@ Status ProgramManager::Build(const ProgramBase& program, struct CreateComputePipelineContext { wgpu::ComputePipeline& pipeline; - std::string error_message; + Status status; } create_pipeline_context{compute_pipeline, {}}; ORT_RETURN_IF_ERROR( @@ -276,23 +209,12 @@ Status ProgramManager::Build(const ProgramBase& program, if (status == wgpu::CreatePipelineAsyncStatus::Success) { context->pipeline = std::move(pipeline); } else { - context->error_message = "Failed to create a WebGPU compute pipeline: "; - context->error_message.append(message.data, message.length); + context->status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create a WebGPU compute pipeline: ", std::string_view{message}); } }, &create_pipeline_context))); - if (create_pipeline_context.error_message.empty()) { - return Status::OK(); - } - - const std::string compilation_diagnostics = GetShaderCompilationDiagnostics(webgpu_context_, shader_module); - if (compilation_diagnostics.empty()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, create_pipeline_context.error_message); - } - - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, create_pipeline_context.error_message, - "\nShader compilation diagnostics:", compilation_diagnostics); + return create_pipeline_context.status; } const ProgramArtifact* ProgramManager::Get(const std::string& key) const { diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 58e71de1fa211..ada9a2e8ab692 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -184,16 +184,6 @@ Status WebGpuContext::Wait(wgpu::Future f) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); } -Status WebGpuContext::WaitForQueueIdle() { - return Wait(device_queue_.OnSubmittedWorkDone( - wgpu::CallbackMode::WaitAnyOnly, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - ORT_ENFORCE(status == wgpu::QueueWorkDoneStatus::Success, - "Failed to wait for submitted WebGPU work: ", - std::string_view{message}); - })); -} - Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) { const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index fb6da131e45fe..021c7f383a6d7 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -165,7 +165,6 @@ class WebGpuContextFactory { class WebGpuContext final { public: Status Wait(wgpu::Future f); - Status WaitForQueueIdle(); const wgpu::Device& Device() const { return device_; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index ff1270938b639..d1cde04277938 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -578,6 +578,16 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, enable_int64_{config.enable_graph_capture || config.enable_int64}, multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset}, prepack_allocator_{std::make_shared(context_.InitializerBufferManager(), false)} { + // If graph capture is enabled, create a dedicated buffer manager for graph mode + if (enable_graph_capture_) { + // Create buffer manager for graph capture mode with appropriate cache modes + graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create( + context_, + webgpu::BufferCacheMode::Graph, + webgpu::BufferCacheMode::GraphSimple, + webgpu::BufferCacheMode::Disabled); + } + if (config.enable_pix_capture) { #if defined(ENABLE_PIX_FOR_WEBGPU_EP) // set pix frame generator @@ -593,7 +603,7 @@ std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { // allocator for initializers std::make_unique(context_.InitializerBufferManager(), true), // default allocator - std::make_unique([this]() -> const webgpu::BufferManager& { return BufferManager(); }, false), + std::make_unique(BufferManager(), false), }; } @@ -763,14 +773,6 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op } if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) { - if (!graph_buffer_mgr_) { - graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create( - context_, - webgpu::BufferCacheMode::Graph, - webgpu::BufferCacheMode::GraphSimple, - webgpu::BufferCacheMode::Disabled); - } - graph_buffer_mgr_active_ = true; context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); } m_current_graph_annotation_id = graph_annotation_id; @@ -792,8 +794,6 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti } } - graph_buffer_mgr_active_ = false; - if (session_profiler_ && session_profiler_->Enabled()) { // Session-level profiling: collect into profiler's own events storage. context_.CollectProfilingData(session_profiler_->GpuEvents()); @@ -825,7 +825,6 @@ bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) { ORT_ENFORCE(IsGraphCaptured(graph_annotation_id)); - ORT_ENFORCE(graph_buffer_mgr_ != nullptr, "Graph buffer manager must exist before replay."); // TODO: enable profiling in run level if (session_profiler_ && session_profiler_->Enabled()) { context_.StartProfiling(); @@ -839,7 +838,7 @@ Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) { } webgpu::BufferManager& WebGpuExecutionProvider::BufferManager() const { - if (graph_buffer_mgr_active_ && graph_buffer_mgr_) { + if (graph_buffer_mgr_) { return *graph_buffer_mgr_; } else { return context_.BufferManager(); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 69171df6a8a45..d1e2231dbba6f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -127,7 +127,6 @@ class WebGpuExecutionProvider : public IExecutionProvider { bool enable_int64_ = false; uint32_t multi_rotary_cache_concat_offset_ = 0; bool is_graph_captured_ = false; - bool graph_buffer_mgr_active_ = false; int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // Required regular runs before graph capture for any necessary allocations. int m_current_graph_annotation_id = 0; diff --git a/onnxruntime/core/session/ort_version_check.h b/onnxruntime/core/session/ort_version_check.h index dea0d4366fbe2..f8fab0367b17d 100644 --- a/onnxruntime/core/session/ort_version_check.h +++ b/onnxruntime/core/session/ort_version_check.h @@ -10,27 +10,20 @@ namespace onnxruntime::version_check { -#if defined(__cpp_consteval) && __cpp_consteval >= 201811L -#define ORT_VERSION_CHECK_CONSTEVAL consteval -#else -#define ORT_VERSION_CHECK_CONSTEVAL constexpr -#endif - -// A simple consteval-friendly result type for ParseUint. -// std::optional triggers an internal compiler error in MSVC 14.44 when used with consteval. +// A simple constexpr-friendly result type for ParseUint. struct ParseUintResult { uint32_t value; bool has_value; - ORT_VERSION_CHECK_CONSTEVAL bool operator==(uint32_t other) const { return has_value && value == other; } - ORT_VERSION_CHECK_CONSTEVAL bool operator!=(uint32_t other) const { return !(*this == other); } + constexpr bool operator==(uint32_t other) const { return has_value && value == other; } + constexpr bool operator!=(uint32_t other) const { return !(*this == other); } }; -inline ORT_VERSION_CHECK_CONSTEVAL ParseUintResult ParseUintNone() { return {0, false}; } +inline constexpr ParseUintResult ParseUintNone() { return {0, false}; } // Parse a non-negative integer from a string_view without leading zeros. // Returns a result with has_value == false on failure (empty, leading zero, non-digit, or overflow). -ORT_VERSION_CHECK_CONSTEVAL ParseUintResult ParseUint(std::string_view str) { +constexpr ParseUintResult ParseUint(std::string_view str) { if (str.empty()) return ParseUintNone(); // Leading zeros are not allowed (except "0" itself). if (str.size() > 1 && str[0] == '0') return ParseUintNone(); @@ -48,7 +41,7 @@ ORT_VERSION_CHECK_CONSTEVAL ParseUintResult ParseUint(std::string_view str) { // - Major version is 1 // - Y and Z are non-negative integers without leading zeros // - Y (minor version) must equal expected_api_version (defaults to ORT_API_VERSION) -ORT_VERSION_CHECK_CONSTEVAL bool IsOrtVersionValid(std::string_view version, uint32_t expected_api_version = ORT_API_VERSION) { +constexpr bool IsOrtVersionValid(std::string_view version, uint32_t expected_api_version = ORT_API_VERSION) { size_t first_dot = version.find('.'); if (first_dot == std::string_view::npos) return false; size_t second_dot = version.find('.', first_dot + 1); @@ -71,6 +64,4 @@ ORT_VERSION_CHECK_CONSTEVAL bool IsOrtVersionValid(std::string_view version, uin return true; } -#undef ORT_VERSION_CHECK_CONSTEVAL - } // namespace onnxruntime::version_check diff --git a/onnxruntime/test/onnx/microbenchmark/main.cc b/onnxruntime/test/onnx/microbenchmark/main.cc index a2cb6aaff281a..b356dda740a31 100644 --- a/onnxruntime/test/onnx/microbenchmark/main.cc +++ b/onnxruntime/test/onnx/microbenchmark/main.cc @@ -79,7 +79,7 @@ int main(int argc, char** argv) { ::benchmark::Initialize(&argc, argv); if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return -1; - ORT_ABORT_ON_ERROR(g_ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env)); + ORT_ABORT_ON_ERROR(g_ort->CreateEnv(ORT_LOGGING_LEVEL_ERROR, "test", &env)); ::benchmark::RunSpecifiedBenchmarks(); g_ort->ReleaseEnv(env); return 0; diff --git a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc b/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc deleted file mode 100644 index a98c3f04d0eaf..0000000000000 --- a/onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc +++ /dev/null @@ -1,1673 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// This benchmark uses Dawn (WebGPU) headers directly, which are only available -// when the WebGPU EP is built (USE_WEBGPU is defined for that EP). On other -// builds the file compiles to an empty translation unit so the benchmark -// target can still link. -#ifdef USE_WEBGPU - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include "core/providers/webgpu/webgpu_provider_options.h" -#include -#include -#include "core/session/onnxruntime_run_options_config_keys.h" - -#include -#include -#include - -extern OrtEnv* env; -extern const OrtApi* g_ort; - -namespace { -constexpr const char* kDecodeBenchmarkModeEnvVar = "ORT_WEBGPU_MATMUL_NBITS_BENCHMARK_MODE"; -constexpr const char* kDecodeBenchmarkGraphCaptureEnvVar = "ORT_WEBGPU_MATMUL_NBITS_ENABLE_GRAPH_CAPTURE"; -constexpr const char* kDecodeBenchmarkOptimizedModelPathEnvVar = "ORT_WEBGPU_MATMUL_NBITS_OPTIMIZED_MODEL_PATH"; -constexpr const char* kDecodeBenchmarkVerboseSessionLogEnvVar = "ORT_WEBGPU_MATMUL_NBITS_VERBOSE_SESSION_LOG"; -constexpr float kDecodeCorrectnessAbsTolerance = 0.1f; -constexpr float kDecodeCorrectnessRelTolerance = 0.01f; -constexpr const char* kBenchmarkGraphCaptureAnnotationId = "1"; - -enum class DecodeBenchmarkMode { - kPerf, - kCorrectness, -}; - -bool IsGraphCaptureBenchmarkEnabled(); -bool IsGraphCaptureBenchmarkEnabled(); -bool IsVerboseSessionLogEnabled(); -std::string GetOptimizedModelPath(); - -enum class MlpDecodeBenchmarkVariant { - kUnfused, - kFused, -}; - -enum class MlpNormKind { - kSimplified, - kSkipSimplified, - kSkipSimplifiedPassthrough, -}; - -struct MlpDecodeBenchConfig { - int64_t n; - int64_t k; - int64_t bits; - int64_t block_size; - int64_t accuracy_level; -}; - -struct AdapterSelectionConfig { - // preferred_device_substring: optional case-insensitive device-name hint. - // context_id: ORT WebGPU custom context ID used to bind the externally created instance/device. - // backend_type: Dawn backend to enumerate adapters from, e.g. D3D12 or Vulkan. - // print_adapter_list: whether to print all discovered adapters before selecting one. - const char* preferred_device_substring; - int context_id; - WGPUBackendType backend_type; - bool print_adapter_list; -}; - -struct AdapterCandidate { - dawn::native::Adapter adapter; - int global_index; - WGPUAdapterType adapter_type; - int type_index; - uint32_t vendor_id; - uint32_t device_id; - std::string vendor; - std::string architecture; - std::string device; - std::string description; -}; - -struct SelectedWebGpuContext { - std::unique_ptr dawn_instance; - WGPUInstance instance{nullptr}; - WGPUDevice device{nullptr}; - std::unordered_map provider_options; - std::string selected_adapter_summary; -}; - -struct MlpTrafficStats { - double input_bytes; - double packed_weight_bytes; - double scale_bytes; - double intermediate_bytes; - double output_bytes; - double total_bytes; -}; - -struct QkvDecodeBenchConfig { - int64_t q_n; - int64_t kv_n; - int64_t k; - int64_t bits; - int64_t block_size; - int64_t accuracy_level; -}; - -enum class QkvDecodeBenchmarkVariant { - kUnfused, - kFused, -}; - -enum class QkvNormKind { - kSimplified, - kSkipSimplified, - kSkipSimplifiedPassthrough, -}; - -struct QkvTrafficStats { - double input_bytes; - double skip_input_bytes; - double norm_scale_bytes; - double packed_weight_bytes; - double scale_bytes; - double intermediate_bytes; - double output_bytes; - double total_bytes; -}; - -constexpr double kRtxTheoreticalBandwidthBytesPerSecond = 448.0 * 1000.0 * 1000.0 * 1000.0; -constexpr int kDecodeWarmupRuns = 25; - -DecodeBenchmarkMode GetDecodeBenchmarkMode() { - std::string mode_env = onnxruntime::Env::Default().GetEnvironmentVar(kDecodeBenchmarkModeEnvVar); - if (mode_env.empty()) { - return DecodeBenchmarkMode::kPerf; - } - - std::transform(mode_env.begin(), mode_env.end(), mode_env.begin(), - [](unsigned char value) { return static_cast(std::tolower(value)); }); - if (mode_env == "0" || mode_env == "false" || mode_env == "off" || - mode_env == "check" || mode_env == "correctness" || mode_env == "validate") { - return DecodeBenchmarkMode::kCorrectness; - } - - return DecodeBenchmarkMode::kPerf; -} - -bool IsDecodeBenchmarkPerfMode() { - return GetDecodeBenchmarkMode() == DecodeBenchmarkMode::kPerf; -} - -std::string GetDecodeBenchmarkLabel(const char* shape_label = nullptr) { - const char* mode_label = IsDecodeBenchmarkPerfMode() ? "perf" : "correctness"; - const char* graph_label = IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"; - - std::ostringstream stream; - stream << "fp16_decode"; - if (shape_label != nullptr && shape_label[0] != '\0') { - stream << '_' << shape_label; - } - stream << '_' << mode_label << "_auto_gpu_" << graph_label; - return stream.str(); -} - -bool IsGraphCaptureBenchmarkEnabled() { - std::string graph_capture_env = onnxruntime::Env::Default().GetEnvironmentVar(kDecodeBenchmarkGraphCaptureEnvVar); - if (graph_capture_env.empty()) { - return false; - } - - std::transform(graph_capture_env.begin(), graph_capture_env.end(), graph_capture_env.begin(), - [](unsigned char value) { return static_cast(std::tolower(value)); }); - return graph_capture_env != "0" && graph_capture_env != "false" && graph_capture_env != "off"; -} - -bool IsVerboseSessionLogEnabled() { - std::string verbose_log_env = onnxruntime::Env::Default().GetEnvironmentVar(kDecodeBenchmarkVerboseSessionLogEnvVar); - if (verbose_log_env.empty()) { - return false; - } - - std::transform(verbose_log_env.begin(), verbose_log_env.end(), verbose_log_env.begin(), - [](unsigned char value) { return static_cast(std::tolower(value)); }); - return verbose_log_env != "0" && verbose_log_env != "false" && verbose_log_env != "off"; -} - -std::string GetOptimizedModelPath() { - return onnxruntime::Env::Default().GetEnvironmentVar(kDecodeBenchmarkOptimizedModelPathEnvVar); -} - -Ort::RunOptions CreateBenchmarkRunOptions() { - Ort::RunOptions run_options; - if (IsGraphCaptureBenchmarkEnabled()) { - run_options.AddConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation, kBenchmarkGraphCaptureAnnotationId); - } - - return run_options; -} - -std::vector GetRequiredDeviceFeatures(const wgpu::Adapter& adapter) { - std::vector required_features; - constexpr wgpu::FeatureName features[]{ -#if !defined(__wasm__) - wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, - wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix, -#endif - wgpu::FeatureName::TimestampQuery, - wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::Subgroups, -#if !defined(__wasm__) - wgpu::FeatureName::BufferMapExtendedUsages, -#endif - }; - for (auto feature : features) { - if (adapter.HasFeature(feature)) { - required_features.push_back(feature); - } - } - return required_features; -} - -wgpu::Limits GetRequiredDeviceLimits(const wgpu::Adapter& adapter) { - wgpu::Limits required_limits{}; - wgpu::Limits adapter_limits{}; - if (!adapter.GetLimits(&adapter_limits)) { - throw std::runtime_error("Failed to query adapter limits for the selected WebGPU adapter."); - } - - required_limits.maxBindGroups = adapter_limits.maxBindGroups; - required_limits.maxComputeWorkgroupStorageSize = adapter_limits.maxComputeWorkgroupStorageSize; - required_limits.maxComputeWorkgroupsPerDimension = adapter_limits.maxComputeWorkgroupsPerDimension; - required_limits.maxStorageBuffersPerShaderStage = adapter_limits.maxStorageBuffersPerShaderStage; - required_limits.maxStorageBufferBindingSize = adapter_limits.maxStorageBufferBindingSize; - required_limits.maxBufferSize = adapter_limits.maxBufferSize; - required_limits.maxComputeInvocationsPerWorkgroup = adapter_limits.maxComputeInvocationsPerWorkgroup; - required_limits.maxComputeWorkgroupSizeX = adapter_limits.maxComputeWorkgroupSizeX; - required_limits.maxComputeWorkgroupSizeY = adapter_limits.maxComputeWorkgroupSizeY; - required_limits.maxComputeWorkgroupSizeZ = adapter_limits.maxComputeWorkgroupSizeZ; - - return required_limits; -} - -std::string ToString(WGPUStringView value) { - return value.data == nullptr ? std::string{} : std::string(value.data, value.length); -} - -const char* AdapterTypeToString(WGPUAdapterType adapter_type) { - switch (adapter_type) { - case WGPUAdapterType_DiscreteGPU: - return "discrete"; - case WGPUAdapterType_IntegratedGPU: - return "integrated"; - case WGPUAdapterType_CPU: - return "cpu"; - default: - return "unknown"; - } -} - -bool IsGpuAdapterType(WGPUAdapterType adapter_type) { - return adapter_type == WGPUAdapterType_DiscreteGPU || - adapter_type == WGPUAdapterType_IntegratedGPU; -} - -std::string FormatAdapterSummary(const AdapterCandidate& adapter) { - std::ostringstream stream; - stream << "adapter[" << adapter.global_index << "]" - << " type=" << AdapterTypeToString(adapter.adapter_type) - << " type_index=" << adapter.type_index - << " vendor=" << adapter.vendor - << " architecture=" << adapter.architecture - << " gpu_name=" << adapter.device - << " description=" << adapter.description - << " vendor_id=" << adapter.vendor_id - << " device_id=" << adapter.device_id; - return stream.str(); -} - -std::string FormatFeatureSupport(const dawn::native::Adapter& adapter) { - const wgpu::Adapter wgpu_adapter = adapter.Get(); - std::ostringstream stream; - stream << "shader_f16=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::ShaderF16) ? "yes" : "no") - << " subgroups=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::Subgroups) ? "yes" : "no") - << " timestamp_query=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::TimestampQuery) ? "yes" : "no"); -#if !defined(__wasm__) - stream << " subgroup_matrix=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix) ? "yes" : "no") - << " buffer_map_extended_usages=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::BufferMapExtendedUsages) ? "yes" : "no") - << " timestamp_query_inside_passes=" << (wgpu_adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses) ? "yes" : "no"); -#endif - return stream.str(); -} - -std::string ToLower(std::string value) { - std::transform(value.begin(), value.end(), value.begin(), - [](unsigned char character) { return static_cast(std::tolower(character)); }); - return value; -} - -MlpTrafficStats CalculateMlpTrafficStats(const MlpDecodeBenchConfig& config, - MlpDecodeBenchmarkVariant variant, - MlpNormKind norm_kind) { - const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; - const int64_t blob_size = (config.block_size * config.bits) / 8; - - const bool is_unfused = variant == MlpDecodeBenchmarkVariant::kUnfused; - const double input_reads = variant == MlpDecodeBenchmarkVariant::kUnfused ? 2.0 : 1.0; - const bool has_skip = norm_kind == MlpNormKind::kSkipSimplified || - norm_kind == MlpNormKind::kSkipSimplifiedPassthrough; - const double skip_input_bytes = has_skip ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; - const double norm_scale_bytes = static_cast(config.k) * sizeof(Ort::Float16_t); - const double intermediate_bytes = is_unfused ? 4.0 * static_cast(config.n) * sizeof(Ort::Float16_t) : 0.0; - const double input_bytes = input_reads * static_cast(config.k) * sizeof(Ort::Float16_t); - const double packed_weight_bytes = - 2.0 * static_cast(config.n) * static_cast(k_blocks) * static_cast(blob_size); - const double scale_bytes = 2.0 * static_cast(config.n) * static_cast(k_blocks) * sizeof(Ort::Float16_t); - const double output_bytes = - static_cast(config.n + (norm_kind == MlpNormKind::kSkipSimplifiedPassthrough ? config.k : 0)) * - sizeof(Ort::Float16_t); - - return { - input_bytes, - packed_weight_bytes, - scale_bytes, - intermediate_bytes, - output_bytes, - input_bytes + skip_input_bytes + norm_scale_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, - }; -} - -QkvTrafficStats CalculateQkvTrafficStats(const QkvDecodeBenchConfig& config, - QkvDecodeBenchmarkVariant variant, - QkvNormKind norm_kind) { - const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; - const int64_t blob_size = (config.block_size * config.bits) / 8; - - const bool has_skip = norm_kind == QkvNormKind::kSkipSimplified || - norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; - const bool has_skip_passthrough = norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; - - const double input_bytes = static_cast(config.k) * sizeof(Ort::Float16_t); - const double skip_input_bytes = has_skip ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; - const double norm_scale_bytes = static_cast(config.k) * sizeof(Ort::Float16_t); - const double packed_weight_bytes = - static_cast(config.q_n + 2 * config.kv_n) * static_cast(k_blocks) * static_cast(blob_size); - const double scale_bytes = - static_cast(config.q_n + 2 * config.kv_n) * static_cast(k_blocks) * sizeof(Ort::Float16_t); - const double intermediate_bytes = - variant == QkvDecodeBenchmarkVariant::kUnfused ? static_cast(config.k) * sizeof(Ort::Float16_t) : 0.0; - const double output_bytes = - static_cast(config.q_n + 2 * config.kv_n + (has_skip_passthrough ? config.k : 0)) * - sizeof(Ort::Float16_t); - - return { - input_bytes, - skip_input_bytes, - norm_scale_bytes, - packed_weight_bytes, - scale_bytes, - intermediate_bytes, - output_bytes, - input_bytes + skip_input_bytes + norm_scale_bytes + packed_weight_bytes + scale_bytes + intermediate_bytes + output_bytes, - }; -} - -AdapterSelectionConfig GetAdapterSelectionConfig() { - // Prefer a 5060 Ti when Dawn exposes one, otherwise fall back to the first - // Dawn-enumerated adapter so the benchmark remains robust across machines. - return { - "5060 Ti", // preferred_device_substring - 1, // context_id - WGPUBackendType_D3D12, // backend_type - true, // print_adapter_list - }; -} - -SelectedWebGpuContext CreateSelectedWebGpuContext() { - const AdapterSelectionConfig config = GetAdapterSelectionConfig(); - - wgpu::InstanceFeatureName required_instance_features[] = {wgpu::InstanceFeatureName::TimedWaitAny}; - wgpu::InstanceDescriptor instance_desc{}; - instance_desc.requiredFeatures = required_instance_features; - instance_desc.requiredFeatureCount = sizeof(required_instance_features) / sizeof(required_instance_features[0]); - - SelectedWebGpuContext selected_context; - selected_context.dawn_instance = std::make_unique(&instance_desc); - -#if !defined(BUILD_DAWN_SHARED_LIBRARY) - static std::once_flag dawn_procs_initialized; - std::call_once(dawn_procs_initialized, []() { - dawnProcSetProcs(&dawn::native::GetProcs()); - }); -#endif - - WGPURequestAdapterOptions adapter_options = WGPU_REQUEST_ADAPTER_OPTIONS_INIT; - adapter_options.backendType = config.backend_type; - adapter_options.powerPreference = WGPUPowerPreference_Undefined; - - std::vector adapters = selected_context.dawn_instance->EnumerateAdapters(&adapter_options); - if (adapters.empty()) { - throw std::runtime_error("No Dawn adapters were found for the configured backend."); - } - - std::vector candidates; - candidates.reserve(adapters.size()); - int discrete_index = 0; - int integrated_index = 0; - int cpu_index = 0; - int unknown_index = 0; - for (size_t i = 0; i < adapters.size(); ++i) { - WGPUAdapterInfo info = WGPU_ADAPTER_INFO_INIT; - if (wgpuAdapterGetInfo(adapters[i].Get(), &info) != WGPUStatus_Success) { - continue; - } - - const WGPUAdapterType adapter_type = info.adapterType; - int current_type_index = 0; - switch (adapter_type) { - case WGPUAdapterType_DiscreteGPU: - current_type_index = discrete_index++; - break; - case WGPUAdapterType_IntegratedGPU: - current_type_index = integrated_index++; - break; - case WGPUAdapterType_CPU: - current_type_index = cpu_index++; - break; - default: - current_type_index = unknown_index++; - break; - } - candidates.push_back(AdapterCandidate{ - adapters[i], - static_cast(i), - adapter_type, - current_type_index, - info.vendorID, - info.deviceID, - ToString(info.vendor), - ToString(info.architecture), - ToString(info.device), - ToString(info.description), - }); - - wgpuAdapterInfoFreeMembers(info); - } - - if (config.print_adapter_list) { - std::cout << "Available Dawn GPU adapters for WebGPU benchmark:" << std::endl; - bool printed_gpu = false; - for (const auto& candidate : candidates) { - if (!IsGpuAdapterType(candidate.adapter_type)) { - continue; - } - - printed_gpu = true; - std::cout << " " << FormatAdapterSummary(candidate) - << " features={" << FormatFeatureSupport(candidate.adapter) << "}" - << std::endl; - } - - if (!printed_gpu) { - std::cout << " No integrated or discrete GPU adapters were found." << std::endl; - } - } - - AdapterCandidate* selected_adapter = nullptr; - if (config.preferred_device_substring != nullptr) { - const std::string preferred_substring = ToLower(config.preferred_device_substring); - for (auto& candidate : candidates) { - if (ToLower(candidate.device).find(preferred_substring) != std::string::npos) { - selected_adapter = &candidate; - break; - } - } - } - - if (selected_adapter == nullptr && !candidates.empty()) { - selected_adapter = &candidates.front(); - } - - if (selected_adapter == nullptr) { - throw std::runtime_error("No Dawn adapter candidates were available for WebGPU benchmark selection."); - } - - const wgpu::Adapter adapter = selected_adapter->adapter.Get(); - std::vector required_features = GetRequiredDeviceFeatures(adapter); - wgpu::Limits required_limits = GetRequiredDeviceLimits(adapter); - wgpu::DeviceDescriptor device_desc{}; - if (!required_features.empty()) { - device_desc.requiredFeatures = required_features.data(); - device_desc.requiredFeatureCount = required_features.size(); - } - device_desc.requiredLimits = &required_limits; - - selected_context.instance = selected_context.dawn_instance->Get(); - selected_context.device = selected_adapter->adapter.CreateDevice(&device_desc); - if (selected_context.device == nullptr) { - throw std::runtime_error("Failed to create a WGPUDevice for the selected adapter."); - } - - selected_context.selected_adapter_summary = FormatAdapterSummary(*selected_adapter); - std::cout << "Selected Dawn adapter for WebGPU benchmark: " - << selected_context.selected_adapter_summary - << " features={" << FormatFeatureSupport(selected_adapter->adapter) << "}" - << std::endl; - - selected_context.provider_options["deviceId"] = std::to_string(config.context_id); - selected_context.provider_options["webgpuInstance"] = std::to_string(reinterpret_cast(selected_context.instance)); - selected_context.provider_options["webgpuDevice"] = std::to_string(reinterpret_cast(selected_context.device)); - selected_context.provider_options["preserveDevice"] = "1"; - selected_context.provider_options["dawnProcTable"] = std::to_string(reinterpret_cast(&dawn::native::GetProcs())); - - return selected_context; -} - -const SelectedWebGpuContext& GetSelectedWebGpuContext() { - static const SelectedWebGpuContext selected_context = CreateSelectedWebGpuContext(); - return selected_context; -} - -template -void AddTensorInitializer(ONNX_NAMESPACE::GraphProto& graph, - const std::string& name, - int32_t data_type, - const std::vector& dims, - const std::vector& values) { - auto* initializer = graph.add_initializer(); - initializer->set_name(name); - initializer->set_data_type(data_type); - for (int64_t dim : dims) { - initializer->add_dims(dim); - } - - initializer->set_raw_data(values.data(), values.size() * sizeof(T)); -} - -void AddTensorValueInfo(ONNX_NAMESPACE::GraphProto& graph, - const std::string& name, - int32_t data_type, - const std::vector& dims) { - auto* value_info = graph.add_value_info(); - value_info->set_name(name); - value_info->mutable_type()->mutable_tensor_type()->set_elem_type(data_type); - auto* shape = value_info->mutable_type()->mutable_tensor_type()->mutable_shape(); - for (int64_t dim : dims) { - shape->add_dim()->set_dim_value(dim); - } -} - -std::vector GetMlpDecodeBenchConfigs() { - // Qwen3-1.7B MLP gate/up decode geometry: hidden=2048, intermediate=6144. - return { - {6144, 2048, 4, 32, 4}, - }; -} - -std::vector GetQkvDecodeBenchConfigs() { - // Qwen3-1.7B attention projection geometry: hidden=2048, q=2048, kv=1024. - return { - {2048, 1024, 2048, 4, 32, 4}, - }; -} - -void AddMatMulNBitsNode(ONNX_NAMESPACE::GraphProto& graph, - const std::string& node_name, - const std::string& input_name, - const std::string& weight_name, - const std::string& scale_name, - const std::string& output_name, - int64_t k, - int64_t n, - int64_t bits, - int64_t block_size, - int64_t accuracy_level) { - auto* node = graph.add_node(); - node->set_name(node_name); - node->set_op_type("MatMulNBits"); - node->set_domain("com.microsoft"); - node->add_input(input_name); - node->add_input(weight_name); - node->add_input(scale_name); - node->add_input(""); - node->add_input(""); - node->add_output(output_name); - - auto* attr_k = node->add_attribute(); - attr_k->set_name("K"); - attr_k->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_k->set_i(k); - - auto* attr_n = node->add_attribute(); - attr_n->set_name("N"); - attr_n->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_n->set_i(n); - - auto* attr_bits = node->add_attribute(); - attr_bits->set_name("bits"); - attr_bits->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_bits->set_i(bits); - - auto* attr_block = node->add_attribute(); - attr_block->set_name("block_size"); - attr_block->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_block->set_i(block_size); - - auto* attr_accuracy = node->add_attribute(); - attr_accuracy->set_name("accuracy_level"); - attr_accuracy->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_accuracy->set_i(accuracy_level); -} - -void AddMatMulNBitsMlpNode(ONNX_NAMESPACE::GraphProto& graph, - const std::string& node_name, - const std::string& input_name, - const std::string& skip_input_name, - const std::string& norm_scale_name, - const std::string& gate_weight_name, - const std::string& gate_scale_name, - const std::string& up_weight_name, - const std::string& up_scale_name, - const std::string& output_name, - const std::string& skip_sum_output_name, - int64_t k, - int64_t n, - int64_t bits, - int64_t block_size, - int64_t accuracy_level) { - auto* node = graph.add_node(); - node->set_name(node_name); - node->set_op_type("MatMulNBitsMlp"); - node->set_domain("com.microsoft"); - node->add_input(input_name); - node->add_input(skip_input_name); - node->add_input(norm_scale_name); - node->add_input(gate_weight_name); - node->add_input(gate_scale_name); - node->add_input(""); - node->add_input(up_weight_name); - node->add_input(up_scale_name); - node->add_input(""); - node->add_output(output_name); - if (!skip_sum_output_name.empty()) { - node->add_output(skip_sum_output_name); - } - - auto* attr_k = node->add_attribute(); - attr_k->set_name("K"); - attr_k->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_k->set_i(k); - - auto* attr_n = node->add_attribute(); - attr_n->set_name("N"); - attr_n->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_n->set_i(n); - - auto* attr_bits = node->add_attribute(); - attr_bits->set_name("bits"); - attr_bits->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_bits->set_i(bits); - - auto* attr_block = node->add_attribute(); - attr_block->set_name("block_size"); - attr_block->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_block->set_i(block_size); - - auto* attr_accuracy = node->add_attribute(); - attr_accuracy->set_name("accuracy_level"); - attr_accuracy->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_accuracy->set_i(accuracy_level); - - auto* attr_activation = node->add_attribute(); - attr_activation->set_name("activation"); - attr_activation->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); - attr_activation->set_s("silu"); -} - -void AddMatMulNBitsQkvNode(ONNX_NAMESPACE::GraphProto& graph, - const std::string& node_name, - const std::string& input_name, - const std::string& skip_input_name, - const std::string& norm_scale_name, - const std::string& q_weight_name, - const std::string& q_scale_name, - const std::string& k_weight_name, - const std::string& k_scale_name, - const std::string& v_weight_name, - const std::string& v_scale_name, - const std::string& q_output_name, - const std::string& k_output_name, - const std::string& v_output_name, - const std::string& skip_sum_output_name, - int64_t k, - int64_t q_n, - int64_t kv_n, - int64_t bits, - int64_t block_size, - int64_t accuracy_level, - float epsilon) { - auto* node = graph.add_node(); - node->set_name(node_name); - node->set_op_type("MatMulNBitsQkv"); - node->set_domain("com.microsoft"); - node->add_input(input_name); - node->add_input(skip_input_name); - node->add_input(norm_scale_name); - node->add_input(q_weight_name); - node->add_input(q_scale_name); - node->add_input(k_weight_name); - node->add_input(k_scale_name); - node->add_input(v_weight_name); - node->add_input(v_scale_name); - node->add_output(q_output_name); - node->add_output(k_output_name); - node->add_output(v_output_name); - if (!skip_sum_output_name.empty()) { - node->add_output(skip_sum_output_name); - } - - auto* attr_k = node->add_attribute(); - attr_k->set_name("K"); - attr_k->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_k->set_i(k); - auto* attr_qn = node->add_attribute(); - attr_qn->set_name("Nq"); - attr_qn->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_qn->set_i(q_n); - auto* attr_kvn = node->add_attribute(); - attr_kvn->set_name("Nkv"); - attr_kvn->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_kvn->set_i(kv_n); - auto* attr_bits = node->add_attribute(); - attr_bits->set_name("bits"); - attr_bits->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_bits->set_i(bits); - auto* attr_block = node->add_attribute(); - attr_block->set_name("block_size"); - attr_block->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_block->set_i(block_size); - auto* attr_accuracy = node->add_attribute(); - attr_accuracy->set_name("accuracy_level"); - attr_accuracy->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); - attr_accuracy->set_i(accuracy_level); - auto* attr_epsilon = node->add_attribute(); - attr_epsilon->set_name("epsilon"); - attr_epsilon->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); - attr_epsilon->set_f(epsilon); -} - -std::string GetMlpVariantLabel(MlpDecodeBenchmarkVariant variant) { - switch (variant) { - case MlpDecodeBenchmarkVariant::kUnfused: - return "unfused"; - case MlpDecodeBenchmarkVariant::kFused: - return "fused"; - } - - return "unknown"; -} - -std::string GetMlpNormKindLabel(MlpNormKind norm_kind) { - switch (norm_kind) { - case MlpNormKind::kSimplified: - return "simplified"; - case MlpNormKind::kSkipSimplified: - return "skip_simplified"; - case MlpNormKind::kSkipSimplifiedPassthrough: - return "skip_simplified_passthrough"; - } - - return "unknown"; -} - -std::string GetMlpDecodeBenchmarkLabel(MlpDecodeBenchmarkVariant variant, MlpNormKind norm_kind) { - std::ostringstream stream; - stream << "fp16_mlp_decode_" << GetMlpNormKindLabel(norm_kind) << '_' << GetMlpVariantLabel(variant) << '_' - << (IsDecodeBenchmarkPerfMode() ? "perf" : "correctness") << '_' - << "auto_gpu_" - << (IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"); - return stream.str(); -} - -std::string GetQkvVariantLabel(QkvDecodeBenchmarkVariant variant) { - return variant == QkvDecodeBenchmarkVariant::kFused ? "fused" : "unfused"; -} - -std::string GetQkvNormKindLabel(QkvNormKind norm_kind) { - switch (norm_kind) { - case QkvNormKind::kSimplified: - return "simplified"; - case QkvNormKind::kSkipSimplified: - return "skip_simplified"; - case QkvNormKind::kSkipSimplifiedPassthrough: - return "skip_simplified_passthrough"; - } - - return "unknown"; -} - -std::string GetQkvDecodeBenchmarkLabel(QkvDecodeBenchmarkVariant variant, QkvNormKind norm_kind) { - std::ostringstream stream; - stream << "fp16_qkv_norm_" << GetQkvNormKindLabel(norm_kind) << '_' << GetQkvVariantLabel(variant) << '_' - << (IsDecodeBenchmarkPerfMode() ? "perf" : "correctness") << '_' - << "auto_gpu_" - << (IsGraphCaptureBenchmarkEnabled() ? "graph_on" : "graph_off"); - return stream.str(); -} - -std::vector SerializeMatMulNBitsMlpModel(const MlpDecodeBenchConfig& config, - MlpDecodeBenchmarkVariant variant, - MlpNormKind norm_kind) { - const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; - const int64_t blob_size = (config.block_size * config.bits) / 8; - - ONNX_NAMESPACE::ModelProto model; - model.set_ir_version(10); - - auto* onnx_opset = model.add_opset_import(); - onnx_opset->set_domain(""); - onnx_opset->set_version(21); - auto* ms_opset = model.add_opset_import(); - ms_opset->set_domain("com.microsoft"); - ms_opset->set_version(1); - - auto* graph = model.mutable_graph(); - switch (variant) { - case MlpDecodeBenchmarkVariant::kFused: - graph->set_name("WebGpuMatMulNBitsMlpDecodeFused"); - break; - case MlpDecodeBenchmarkVariant::kUnfused: - default: - graph->set_name("WebGpuMatMulNBitsMlpDecodeUnfused"); - break; - } - - const bool has_skip = norm_kind == MlpNormKind::kSkipSimplified || - norm_kind == MlpNormKind::kSkipSimplifiedPassthrough; - const bool has_skip_passthrough = norm_kind == MlpNormKind::kSkipSimplifiedPassthrough; - - auto* input = graph->add_input(); - input->set_name("A"); - input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); - - if (has_skip) { - auto* skip_input = graph->add_input(); - skip_input->set_name("Skip"); - skip_input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - skip_input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - skip_input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); - } - - auto* output = graph->add_output(); - output->set_name("Y"); - output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.n); - if (has_skip_passthrough) { - auto* skip_sum_output = graph->add_output(); - skip_sum_output->set_name("SkipSum"); - skip_sum_output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - skip_sum_output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - skip_sum_output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); - } - - std::vector gate_b(static_cast(config.n * k_blocks * blob_size), uint8_t{0x11}); - std::vector up_b(static_cast(config.n * k_blocks * blob_size), uint8_t{0x77}); - std::vector gate_scales(static_cast(config.n * k_blocks), Ort::Float16_t(0.03125f)); - std::vector up_scales(static_cast(config.n * k_blocks), Ort::Float16_t(0.0625f)); - std::vector norm_scale(static_cast(config.k), Ort::Float16_t(1.0f)); - AddTensorInitializer(*graph, "gate_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, - {config.n, k_blocks, blob_size}, gate_b); - AddTensorInitializer(*graph, "up_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, - {config.n, k_blocks, blob_size}, up_b); - AddTensorInitializer(*graph, "gate_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - {config.n, k_blocks}, gate_scales); - AddTensorInitializer(*graph, "up_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - {config.n, k_blocks}, up_scales); - AddTensorInitializer(*graph, "norm_scale", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - {config.k}, norm_scale); - - if (variant == MlpDecodeBenchmarkVariant::kFused) { - AddMatMulNBitsMlpNode(*graph, - "MatMulNBitsMlpDecode", - "A", - has_skip ? "Skip" : "", - "norm_scale", - "gate_B", - "gate_scales", - "up_B", - "up_scales", - "Y", - has_skip_passthrough ? "SkipSum" : "", - config.k, - config.n, - config.bits, - config.block_size, - config.accuracy_level); - } else { - const char* mlp_input_name = "A_norm"; - AddTensorValueInfo(*graph, "A_norm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); - auto* norm = graph->add_node(); - norm->set_name(has_skip ? "InputSkipSimplifiedLayerNorm" : "InputSimplifiedLayerNorm"); - norm->set_op_type(has_skip ? "SkipSimplifiedLayerNormalization" : "SimplifiedLayerNormalization"); - if (has_skip) { - norm->set_domain("com.microsoft"); - norm->add_input("A"); - norm->add_input("Skip"); - norm->add_input("norm_scale"); - norm->add_output("A_norm"); - if (has_skip_passthrough) { - norm->add_output(""); - norm->add_output(""); - norm->add_output("SkipSum"); - } - } else { - norm->add_input("A"); - norm->add_input("norm_scale"); - norm->add_output("A_norm"); - } - auto* attr_epsilon = norm->add_attribute(); - attr_epsilon->set_name("epsilon"); - attr_epsilon->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); - attr_epsilon->set_f(1e-6f); - - AddTensorValueInfo(*graph, "gate_mm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); - AddTensorValueInfo(*graph, "up_mm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); - AddTensorValueInfo(*graph, "gate_sigmoid", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); - AddTensorValueInfo(*graph, "gate_silu", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.n}); - - AddMatMulNBitsNode(*graph, - "GateMatMulNBitsDecode", - mlp_input_name, - "gate_B", - "gate_scales", - "gate_mm", - config.k, - config.n, - config.bits, - config.block_size, - config.accuracy_level); - AddMatMulNBitsNode(*graph, - "UpMatMulNBitsDecode", - mlp_input_name, - "up_B", - "up_scales", - "up_mm", - config.k, - config.n, - config.bits, - config.block_size, - config.accuracy_level); - - auto* sigmoid = graph->add_node(); - sigmoid->set_name("GateSigmoid"); - sigmoid->set_op_type("Sigmoid"); - sigmoid->add_input("gate_mm"); - sigmoid->add_output("gate_sigmoid"); - - auto* silu_mul = graph->add_node(); - silu_mul->set_name("GateSiluMul"); - silu_mul->set_op_type("Mul"); - silu_mul->add_input("gate_mm"); - silu_mul->add_input("gate_sigmoid"); - silu_mul->add_output("gate_silu"); - - auto* output_mul = graph->add_node(); - output_mul->set_name("GateUpMul"); - output_mul->set_op_type("Mul"); - output_mul->add_input("gate_silu"); - output_mul->add_input("up_mm"); - output_mul->add_output("Y"); - } - - const auto serialized = model.SerializeAsString(); - return std::vector(serialized.begin(), serialized.end()); -} - -std::vector SerializeMatMulNBitsQkvModel(const QkvDecodeBenchConfig& config, - QkvDecodeBenchmarkVariant variant, - QkvNormKind norm_kind) { - const int64_t k_blocks = (config.k + config.block_size - 1) / config.block_size; - const int64_t blob_size = (config.block_size * config.bits) / 8; - - ONNX_NAMESPACE::ModelProto model; - model.set_ir_version(10); - - auto* onnx_opset = model.add_opset_import(); - onnx_opset->set_domain(""); - onnx_opset->set_version(21); - auto* ms_opset = model.add_opset_import(); - ms_opset->set_domain("com.microsoft"); - ms_opset->set_version(1); - - auto* graph = model.mutable_graph(); - graph->set_name(variant == QkvDecodeBenchmarkVariant::kFused - ? (norm_kind == QkvNormKind::kSkipSimplified ? "WebGpuMatMulNBitsQkvSkipNormFused" : "WebGpuMatMulNBitsQkvSimplifiedNormFused") - : (norm_kind == QkvNormKind::kSkipSimplified ? "WebGpuMatMulNBitsQkvSkipNormUnfused" : "WebGpuMatMulNBitsQkvSimplifiedNormUnfused")); - - const bool has_skip = norm_kind == QkvNormKind::kSkipSimplified || - norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; - const bool has_skip_passthrough = norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; - - auto* input = graph->add_input(); - input->set_name("A"); - input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); - - if (has_skip) { - auto* skip_input = graph->add_input(); - skip_input->set_name("Skip"); - skip_input->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - skip_input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - skip_input->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(config.k); - } - - auto add_output = [&](const std::string& name, int64_t n) { - auto* output = graph->add_output(); - output->set_name(name); - output->mutable_type()->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - output->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(n); - }; - add_output("Q", config.q_n); - add_output("K", config.kv_n); - add_output("V", config.kv_n); - if (has_skip_passthrough) { - add_output("SkipSum", config.k); - } - - std::vector norm_scale(static_cast(config.k), Ort::Float16_t(1.0f)); - std::vector q_b(static_cast(config.q_n * k_blocks * blob_size), uint8_t{0x11}); - std::vector k_b(static_cast(config.kv_n * k_blocks * blob_size), uint8_t{0x33}); - std::vector v_b(static_cast(config.kv_n * k_blocks * blob_size), uint8_t{0x77}); - std::vector q_scales(static_cast(config.q_n * k_blocks), Ort::Float16_t(0.03125f)); - std::vector k_scales(static_cast(config.kv_n * k_blocks), Ort::Float16_t(0.03125f)); - std::vector v_scales(static_cast(config.kv_n * k_blocks), Ort::Float16_t(0.0625f)); - - AddTensorInitializer(*graph, "norm_scale", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.k}, norm_scale); - AddTensorInitializer(*graph, "q_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, {config.q_n, k_blocks, blob_size}, q_b); - AddTensorInitializer(*graph, "q_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.q_n, k_blocks}, q_scales); - AddTensorInitializer(*graph, "k_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, {config.kv_n, k_blocks, blob_size}, k_b); - AddTensorInitializer(*graph, "k_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.kv_n, k_blocks}, k_scales); - AddTensorInitializer(*graph, "v_B", ONNX_NAMESPACE::TensorProto_DataType_UINT8, {config.kv_n, k_blocks, blob_size}, v_b); - AddTensorInitializer(*graph, "v_scales", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {config.kv_n, k_blocks}, v_scales); - - if (variant == QkvDecodeBenchmarkVariant::kFused) { - AddMatMulNBitsQkvNode(*graph, - "MatMulNBitsQkvDecode", - "A", - has_skip ? "Skip" : "", - "norm_scale", - "q_B", - "q_scales", - "k_B", - "k_scales", - "v_B", - "v_scales", - "Q", - "K", - "V", - has_skip_passthrough ? "SkipSum" : "", - config.k, - config.q_n, - config.kv_n, - config.bits, - config.block_size, - config.accuracy_level, - 1e-6f); - } else { - AddTensorValueInfo(*graph, "A_norm", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); - auto* norm = graph->add_node(); - norm->set_name(has_skip ? "InputSkipSimplifiedLayerNorm" : "InputSimplifiedLayerNorm"); - norm->set_op_type(has_skip ? "SkipSimplifiedLayerNormalization" : "SimplifiedLayerNormalization"); - if (has_skip) { - if (has_skip_passthrough) { - AddTensorValueInfo(*graph, "SkipSum", ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, {1, config.k}); - } - norm->set_domain("com.microsoft"); - norm->add_input("A"); - norm->add_input("Skip"); - norm->add_input("norm_scale"); - norm->add_output("A_norm"); - if (has_skip_passthrough) { - norm->add_output(""); - norm->add_output(""); - norm->add_output("SkipSum"); - } - } else { - norm->add_input("A"); - norm->add_input("norm_scale"); - norm->add_output("A_norm"); - } - auto* attr_epsilon = norm->add_attribute(); - attr_epsilon->set_name("epsilon"); - attr_epsilon->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); - attr_epsilon->set_f(1e-6f); - - AddMatMulNBitsNode(*graph, "QMatMulNBitsDecode", "A_norm", "q_B", "q_scales", "Q", config.k, config.q_n, config.bits, config.block_size, config.accuracy_level); - AddMatMulNBitsNode(*graph, "KMatMulNBitsDecode", "A_norm", "k_B", "k_scales", "K", config.k, config.kv_n, config.bits, config.block_size, config.accuracy_level); - AddMatMulNBitsNode(*graph, "VMatMulNBitsDecode", "A_norm", "v_B", "v_scales", "V", config.k, config.kv_n, config.bits, config.block_size, config.accuracy_level); - } - - const auto serialized = model.SerializeAsString(); - return std::vector(serialized.begin(), serialized.end()); -} - -Ort::Session CreateSessionFromModelData(const std::vector& model_data, - const std::unordered_map* provider_options, - GraphOptimizationLevel graph_optimization_level = GraphOptimizationLevel::ORT_ENABLE_ALL) { - Ort::SessionOptions session_options; - session_options.DisableMemPattern(); - session_options.SetGraphOptimizationLevel(graph_optimization_level); - if (IsVerboseSessionLogEnabled()) { - session_options.SetLogSeverityLevel(0); - } - - const std::string optimized_model_path = GetOptimizedModelPath(); - if (!optimized_model_path.empty()) { - const auto optimized_model_path_ort = onnxruntime::ToWideString(optimized_model_path); - session_options.SetOptimizedModelFilePath(optimized_model_path_ort.c_str()); - } - - if (provider_options != nullptr) { - if (IsGraphCaptureBenchmarkEnabled()) { - session_options.AddConfigEntry(onnxruntime::webgpu::options::kEnableGraphCapture, - onnxruntime::webgpu::options::kEnableGraphCapture_ON); - } - session_options.AppendExecutionProvider("WebGPU", *provider_options); - } - - OrtSession* raw_session = nullptr; - OrtStatus* status = g_ort->CreateSessionFromArray(env, model_data.data(), model_data.size(), session_options, &raw_session); - if (status != nullptr) { - std::string error_message = g_ort->GetErrorMessage(status); - g_ort->ReleaseStatus(status); - throw std::runtime_error(error_message); - } - - return Ort::Session{raw_session}; -} - -void ValidateDecodeOutputs(const std::vector& model_data, - Ort::Session& webgpu_session, - const char* const* input_names, - const Ort::Value* input_tensor, - const char* const* output_names) { - Ort::Session cpu_session = CreateSessionFromModelData(model_data, nullptr); - - auto webgpu_outputs = webgpu_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensor, 1, output_names, 1); - auto cpu_outputs = cpu_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensor, 1, output_names, 1); - - if (webgpu_outputs.size() != 1 || cpu_outputs.size() != 1) { - throw std::runtime_error("Expected a single output from both WebGPU and CPU sessions."); - } - - const auto& webgpu_output = webgpu_outputs[0]; - const auto& cpu_output = cpu_outputs[0]; - const size_t element_count = webgpu_output.GetTensorTypeAndShapeInfo().GetElementCount(); - if (element_count != cpu_output.GetTensorTypeAndShapeInfo().GetElementCount()) { - throw std::runtime_error("WebGPU and CPU output sizes do not match."); - } - - const auto* webgpu_data = webgpu_output.GetTensorData(); - const auto* cpu_data = cpu_output.GetTensorData(); - float max_abs_diff = 0.0f; - size_t max_abs_diff_index = 0; - for (size_t i = 0; i < element_count; ++i) { - const float webgpu_value = webgpu_data[i].ToFloat(); - const float cpu_value = cpu_data[i].ToFloat(); - const float abs_diff = std::abs(webgpu_value - cpu_value); - const float allowed_diff = kDecodeCorrectnessAbsTolerance + - kDecodeCorrectnessRelTolerance * std::max(std::abs(webgpu_value), std::abs(cpu_value)); - if (abs_diff > max_abs_diff) { - max_abs_diff = abs_diff; - max_abs_diff_index = i; - } - if (abs_diff > allowed_diff) { - std::ostringstream stream; - stream << "Decode correctness check failed at index " << i - << ": webgpu=" << webgpu_value - << ", cpu=" << cpu_value - << ", abs_diff=" << abs_diff - << ", allowed_diff=" << allowed_diff; - throw std::runtime_error(stream.str()); - } - } - - std::cout << "Decode correctness check passed. max_abs_diff=" << max_abs_diff - << " at index " << max_abs_diff_index << std::endl; -} - -void ValidateMlpDecodeOutputs(const std::vector& unfused_model_data, - const std::vector& fused_model_data, - const std::unordered_map& provider_options, - const char* const* input_names, - const Ort::Value* input_tensors, - size_t input_count, - const char* const* output_names, - size_t output_count) { - Ort::Session unfused_session = CreateSessionFromModelData(unfused_model_data, - &provider_options, - GraphOptimizationLevel::ORT_DISABLE_ALL); - Ort::Session fused_session = CreateSessionFromModelData(fused_model_data, - &provider_options, - GraphOptimizationLevel::ORT_ENABLE_ALL); - - auto unfused_outputs = unfused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, output_count); - auto fused_outputs = fused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, output_count); - - for (size_t output_index = 0; output_index < output_count; ++output_index) { - const auto& unfused_output = unfused_outputs[output_index]; - const auto& fused_output = fused_outputs[output_index]; - const size_t element_count = unfused_output.GetTensorTypeAndShapeInfo().GetElementCount(); - if (element_count != fused_output.GetTensorTypeAndShapeInfo().GetElementCount()) { - throw std::runtime_error("Unfused and fused MLP output sizes do not match."); - } - - const auto* unfused_data = unfused_output.GetTensorData(); - const auto* fused_data = fused_output.GetTensorData(); - float max_abs_diff = 0.0f; - size_t max_abs_diff_index = 0; - for (size_t i = 0; i < element_count; ++i) { - const float unfused_value = unfused_data[i].ToFloat(); - const float fused_value = fused_data[i].ToFloat(); - const float abs_diff = std::abs(unfused_value - fused_value); - const float allowed_diff = kDecodeCorrectnessAbsTolerance + - kDecodeCorrectnessRelTolerance * std::max(std::abs(unfused_value), std::abs(fused_value)); - if (abs_diff > max_abs_diff) { - max_abs_diff = abs_diff; - max_abs_diff_index = i; - } - if (abs_diff > allowed_diff) { - std::ostringstream stream; - stream << "MLP decode correctness check failed on output " << output_index - << " at index " << i - << ": unfused=" << unfused_value - << ", fused=" << fused_value - << ", abs_diff=" << abs_diff - << ", allowed_diff=" << allowed_diff; - throw std::runtime_error(stream.str()); - } - } - - std::cout << "MLP decode correctness check passed for output " << output_index - << ". max_abs_diff=" << max_abs_diff - << " at index " << max_abs_diff_index << std::endl; - } -} - -void ValidateQkvDecodeOutputs(const std::vector& unfused_model_data, - const std::vector& fused_model_data, - const std::unordered_map& provider_options, - const char* const* input_names, - const Ort::Value* input_tensors, - size_t input_count, - const char* const* output_names, - size_t output_count) { - Ort::Session unfused_session = CreateSessionFromModelData(unfused_model_data, &provider_options, GraphOptimizationLevel::ORT_DISABLE_ALL); - Ort::Session fused_session = CreateSessionFromModelData(fused_model_data, &provider_options, GraphOptimizationLevel::ORT_DISABLE_ALL); - - auto unfused_outputs = unfused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, output_count); - auto fused_outputs = fused_session.Run(Ort::RunOptions{nullptr}, input_names, input_tensors, input_count, output_names, output_count); - - for (size_t output_index = 0; output_index < output_count; ++output_index) { - const size_t element_count = unfused_outputs[output_index].GetTensorTypeAndShapeInfo().GetElementCount(); - const auto* unfused_data = unfused_outputs[output_index].GetTensorData(); - const auto* fused_data = fused_outputs[output_index].GetTensorData(); - for (size_t i = 0; i < element_count; ++i) { - const float unfused_value = unfused_data[i].ToFloat(); - const float fused_value = fused_data[i].ToFloat(); - const float abs_diff = std::abs(unfused_value - fused_value); - const float allowed_diff = kDecodeCorrectnessAbsTolerance + - kDecodeCorrectnessRelTolerance * std::max(std::abs(unfused_value), std::abs(fused_value)); - if (abs_diff > allowed_diff) { - std::ostringstream stream; - stream << "QKV decode correctness check failed on output " << output_index - << " at index " << i - << ": unfused=" << unfused_value - << ", fused=" << fused_value - << ", abs_diff=" << abs_diff - << ", allowed_diff=" << allowed_diff; - throw std::runtime_error(stream.str()); - } - } - } - - std::cout << "QKV decode correctness check passed." << std::endl; -} - -void BenchmarkWebGpuMatMulNBitsQkvDecode(benchmark::State& state, - QkvDecodeBenchmarkVariant variant, - QkvNormKind norm_kind) { - try { - const QkvDecodeBenchConfig config{ - state.range(0), - state.range(1), - state.range(2), - state.range(3), - state.range(4), - state.range(5), - }; - - if (config.k % config.block_size != 0) { - state.SkipWithError("K must be divisible by block_size for this benchmark skeleton."); - return; - } - - const QkvTrafficStats traffic = CalculateQkvTrafficStats(config, variant, norm_kind); - std::vector model_data = SerializeMatMulNBitsQkvModel(config, variant, norm_kind); - const SelectedWebGpuContext& selected_context = GetSelectedWebGpuContext(); - Ort::Session session = CreateSessionFromModelData(model_data, - &selected_context.provider_options, - GraphOptimizationLevel::ORT_DISABLE_ALL); - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - std::vector input_shape{1, config.k}; - std::vector activation(static_cast(config.k)); - std::vector skip_activation(static_cast(config.k)); - - std::mt19937 rng(123); - std::uniform_real_distribution dist(-1.0f, 1.0f); - for (auto& value : activation) { - value = Ort::Float16_t(dist(rng)); - } - for (auto& value : skip_activation) { - value = Ort::Float16_t(dist(rng)); - } - - const bool has_skip = norm_kind == QkvNormKind::kSkipSimplified || - norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; - const bool has_skip_passthrough = norm_kind == QkvNormKind::kSkipSimplifiedPassthrough; - const char* simplified_input_names[] = {"A"}; - const char* skip_input_names[] = {"A", "Skip"}; - const char* simplified_output_names[] = {"Q", "K", "V"}; - const char* skip_output_names[] = {"Q", "K", "V"}; - const char* skip_passthrough_output_names[] = {"Q", "K", "V", "SkipSum"}; - const char* const* input_names = has_skip ? skip_input_names : simplified_input_names; - const char* const* output_names = has_skip_passthrough ? skip_passthrough_output_names - : (has_skip ? skip_output_names : simplified_output_names); - const size_t input_count = has_skip ? 2u : 1u; - const size_t output_count = has_skip_passthrough ? 4u : 3u; - - std::array input_tensors = { - Ort::Value::CreateTensor(memory_info, - activation.data(), - activation.size(), - input_shape.data(), - input_shape.size()), - Ort::Value::CreateTensor(memory_info, - skip_activation.data(), - skip_activation.size(), - input_shape.data(), - input_shape.size())}; - Ort::RunOptions run_options = CreateBenchmarkRunOptions(); - - if (!IsDecodeBenchmarkPerfMode() && variant == QkvDecodeBenchmarkVariant::kFused) { - ValidateQkvDecodeOutputs(SerializeMatMulNBitsQkvModel(config, QkvDecodeBenchmarkVariant::kUnfused, norm_kind), - model_data, - selected_context.provider_options, - input_names, - input_tensors.data(), - input_count, - output_names, - output_count); - } - - for (int i = 0; i < kDecodeWarmupRuns; ++i) { - auto warmup_outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, output_count); - benchmark::DoNotOptimize(warmup_outputs); - } - - double total_kernel_seconds = 0.0; - for (auto _ : state) { - const auto kernel_start = std::chrono::steady_clock::now(); - auto outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, output_count); - const auto kernel_end = std::chrono::steady_clock::now(); - total_kernel_seconds += std::chrono::duration(kernel_end - kernel_start).count(); - benchmark::DoNotOptimize(outputs); - } - - const double total_flops = 2.0 * static_cast(config.k) * static_cast(config.q_n + 2 * config.kv_n); - const double achieved_bandwidth_bytes_per_second = - total_kernel_seconds > 0.0 - ? traffic.total_bytes * static_cast(state.iterations()) / total_kernel_seconds - : 0.0; - - state.SetLabel(GetQkvDecodeBenchmarkLabel(variant, norm_kind)); - state.counters["TFLOPS"] = benchmark::Counter(total_flops, benchmark::Counter::kIsIterationInvariantRate); - state.counters["ApproxMemBW_GBps"] = benchmark::Counter(achieved_bandwidth_bytes_per_second / 1.0e9); - state.counters["ApproxTraffic_MB"] = benchmark::Counter(traffic.total_bytes / 1.0e6); - state.counters["Input_MB"] = benchmark::Counter(traffic.input_bytes / 1.0e6); - state.counters["SkipInput_MB"] = benchmark::Counter(traffic.skip_input_bytes / 1.0e6); - state.counters["NormScale_MB"] = benchmark::Counter(traffic.norm_scale_bytes / 1.0e6); - state.counters["PackedW_MB"] = benchmark::Counter(traffic.packed_weight_bytes / 1.0e6); - state.counters["Scales_MB"] = benchmark::Counter(traffic.scale_bytes / 1.0e6); - state.counters["Intermediate_MB"] = benchmark::Counter(traffic.intermediate_bytes / 1.0e6); - state.counters["Output_MB"] = benchmark::Counter(traffic.output_bytes / 1.0e6); - state.counters["GraphReplay"] = benchmark::Counter(IsGraphCaptureBenchmarkEnabled() ? 1.0 : 0.0); - } catch (const std::exception& ex) { - state.SkipWithError(ex.what()); - } -} - -void BenchmarkWebGpuMatMulNBitsMlpDecode(benchmark::State& state, - MlpDecodeBenchmarkVariant variant, - MlpNormKind norm_kind) { - try { - const MlpDecodeBenchConfig config{ - state.range(0), - state.range(1), - state.range(2), - state.range(3), - state.range(4), - }; - - if (config.k % config.block_size != 0) { - state.SkipWithError("K must be divisible by block_size for this benchmark skeleton."); - return; - } - - const MlpTrafficStats traffic = CalculateMlpTrafficStats(config, variant, norm_kind); - std::vector model_data = SerializeMatMulNBitsMlpModel(config, variant, norm_kind); - const SelectedWebGpuContext& selected_context = GetSelectedWebGpuContext(); - const GraphOptimizationLevel optimization_level = - variant == MlpDecodeBenchmarkVariant::kUnfused ? GraphOptimizationLevel::ORT_DISABLE_ALL - : GraphOptimizationLevel::ORT_ENABLE_ALL; - Ort::Session session = CreateSessionFromModelData(model_data, - &selected_context.provider_options, - optimization_level); - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - std::vector input_shape{1, config.k}; - std::vector activation(static_cast(config.k)); - std::vector skip_activation(static_cast(config.k)); - std::mt19937 rng(123); - std::uniform_real_distribution dist(-1.0f, 1.0f); - for (auto& value : activation) { - value = Ort::Float16_t(dist(rng)); - } - for (auto& value : skip_activation) { - value = Ort::Float16_t(dist(rng)); - } - - const bool has_skip = norm_kind == MlpNormKind::kSkipSimplified || - norm_kind == MlpNormKind::kSkipSimplifiedPassthrough; - const bool has_skip_passthrough = norm_kind == MlpNormKind::kSkipSimplifiedPassthrough; - const char* simplified_input_names[] = {"A"}; - const char* skip_input_names[] = {"A", "Skip"}; - const char* main_output_names[] = {"Y"}; - const char* skip_passthrough_output_names[] = {"Y", "SkipSum"}; - const char* const* input_names = has_skip ? skip_input_names : simplified_input_names; - const char* const* output_names = has_skip_passthrough ? skip_passthrough_output_names : main_output_names; - const size_t input_count = has_skip ? 2u : 1u; - const size_t output_count = has_skip_passthrough ? 2u : 1u; - std::array input_tensors = { - Ort::Value::CreateTensor(memory_info, - activation.data(), - activation.size(), - input_shape.data(), - input_shape.size()), - Ort::Value::CreateTensor(memory_info, - skip_activation.data(), - skip_activation.size(), - input_shape.data(), - input_shape.size())}; - Ort::RunOptions run_options = CreateBenchmarkRunOptions(); - - if (!IsDecodeBenchmarkPerfMode()) { - ValidateMlpDecodeOutputs(SerializeMatMulNBitsMlpModel(config, MlpDecodeBenchmarkVariant::kUnfused, norm_kind), - SerializeMatMulNBitsMlpModel(config, variant, norm_kind), - selected_context.provider_options, - input_names, - input_tensors.data(), - input_count, - output_names, - output_count); - } - - for (int i = 0; i < kDecodeWarmupRuns; ++i) { - auto warmup_outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, output_count); - benchmark::DoNotOptimize(warmup_outputs); - } - - double total_kernel_seconds = 0.0; - for (auto _ : state) { - const auto kernel_start = std::chrono::steady_clock::now(); - auto outputs = session.Run(run_options, input_names, input_tensors.data(), input_count, output_names, output_count); - const auto kernel_end = std::chrono::steady_clock::now(); - total_kernel_seconds += std::chrono::duration(kernel_end - kernel_start).count(); - benchmark::DoNotOptimize(outputs); - } - - const double total_flops = 4.0 * static_cast(config.n) * static_cast(config.k); - const double achieved_bandwidth_bytes_per_second = - total_kernel_seconds > 0.0 - ? traffic.total_bytes * static_cast(state.iterations()) / total_kernel_seconds - : 0.0; - - state.SetLabel(GetMlpDecodeBenchmarkLabel(variant, norm_kind)); - state.counters["TFLOPS"] = benchmark::Counter( - total_flops, - benchmark::Counter::kIsIterationInvariantRate); - state.counters["ApproxMemBW_GBps"] = benchmark::Counter(achieved_bandwidth_bytes_per_second / 1.0e9); - state.counters["ApproxTraffic_MB"] = benchmark::Counter(traffic.total_bytes / 1.0e6); - state.counters["Input_MB"] = benchmark::Counter(traffic.input_bytes / 1.0e6); - state.counters["PackedW_MB"] = benchmark::Counter(traffic.packed_weight_bytes / 1.0e6); - state.counters["Scales_MB"] = benchmark::Counter(traffic.scale_bytes / 1.0e6); - state.counters["Intermediate_MB"] = benchmark::Counter(traffic.intermediate_bytes / 1.0e6); - state.counters["Output_MB"] = benchmark::Counter(traffic.output_bytes / 1.0e6); - state.counters["GraphReplay"] = benchmark::Counter(IsGraphCaptureBenchmarkEnabled() ? 1.0 : 0.0); - } catch (const std::exception& ex) { - state.SkipWithError(ex.what()); - } -} - -static void BM_WebGpuMatMulNBitsMlpSimplifiedDecodeUnfused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused, MlpNormKind::kSimplified); -} - -static void BM_WebGpuMatMulNBitsMlpSimplifiedDecodeFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused, MlpNormKind::kSimplified); -} - -static void BM_WebGpuMatMulNBitsQkvSimplifiedNormDecodeUnfused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kUnfused, QkvNormKind::kSimplified); -} - -static void BM_WebGpuMatMulNBitsQkvSimplifiedNormDecodeFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kFused, QkvNormKind::kSimplified); -} - -static void BM_WebGpuMatMulNBitsQkvSkipDecodeUnfused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kUnfused, QkvNormKind::kSkipSimplified); -} - -static void BM_WebGpuMatMulNBitsQkvSkipDecodeFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kFused, QkvNormKind::kSkipSimplified); -} - -static void BM_WebGpuMatMulNBitsQkvSkipPassthroughDecodeUnfused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kUnfused, QkvNormKind::kSkipSimplifiedPassthrough); -} - -static void BM_WebGpuMatMulNBitsQkvSkipPassthroughDecodeFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsQkvDecode(state, QkvDecodeBenchmarkVariant::kFused, QkvNormKind::kSkipSimplifiedPassthrough); -} - -static void BM_WebGpuMatMulNBitsMlpSkipDecodeUnfused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused, MlpNormKind::kSkipSimplified); -} - -static void BM_WebGpuMatMulNBitsMlpSkipDecodeFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused, MlpNormKind::kSkipSimplified); -} - -static void BM_WebGpuMatMulNBitsMlpSkipPassthroughDecodeUnfused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kUnfused, MlpNormKind::kSkipSimplifiedPassthrough); -} - -static void BM_WebGpuMatMulNBitsMlpSkipPassthroughDecodeFused(benchmark::State& state) { - BenchmarkWebGpuMatMulNBitsMlpDecode(state, MlpDecodeBenchmarkVariant::kFused, MlpNormKind::kSkipSimplifiedPassthrough); -} - -void ApplyWebGpuMatMulNBitsMlpDecodeArgs(benchmark::internal::Benchmark* benchmark) { - for (const auto& config : GetMlpDecodeBenchConfigs()) { - benchmark->Args({config.n, config.k, config.bits, config.block_size, config.accuracy_level}); - } -} - -void ApplyWebGpuMatMulNBitsQkvDecodeArgs(benchmark::internal::Benchmark* benchmark) { - for (const auto& config : GetQkvDecodeBenchConfigs()) { - benchmark->Args({config.q_n, config.kv_n, config.k, config.bits, config.block_size, config.accuracy_level}); - } -} - -// Qkv benchmarks -BENCHMARK(BM_WebGpuMatMulNBitsQkvSimplifiedNormDecodeUnfused) - ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsQkvSimplifiedNormDecodeFused) - ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipDecodeUnfused) - ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipDecodeFused) - ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipPassthroughDecodeUnfused) - ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsQkvSkipPassthroughDecodeFused) - ->Apply(ApplyWebGpuMatMulNBitsQkvDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -// Mlp benchmarks -BENCHMARK(BM_WebGpuMatMulNBitsMlpSimplifiedDecodeUnfused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsMlpSimplifiedDecodeFused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeUnfused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipDecodeFused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipPassthroughDecodeUnfused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -BENCHMARK(BM_WebGpuMatMulNBitsMlpSkipPassthroughDecodeFused) - ->Apply(ApplyWebGpuMatMulNBitsMlpDecodeArgs) - ->ReportAggregatesOnly() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kMicrosecond); - -} // namespace - -#endif // USE_WEBGPU From 2039c7f2099fa17acca63f3753a9d70011601443 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 1 May 2026 21:00:03 -0700 Subject: [PATCH 15/26] Remove unused dp4a_matmul_mlp.wgsl.template This template file was added speculatively but is not referenced by any kernel, include, or build rule. Removing to keep the PR clean. --- .../dp4a_matmul_mlp.wgsl.template | 110 ------------------ 1 file changed, 110 deletions(-) delete mode 100644 onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_mlp.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_mlp.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_mlp.wgsl.template deleted file mode 100644 index 77dc522130f82..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_mlp.wgsl.template +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#param tile_size -#param tile_size_k_vec -#param single_scale_weights -#param has_gate_bias -#param has_up_bias - -#use .getByOffset .setByOffset - -#include "quantization/dp4a_matmul_common.wgsl.template" - -const double_tile_size_k_vec = 2 * tile_size_k_vec; -const scale_a_size_in_tile_a = double_tile_size_k_vec / 8; - -var gate_inter_results: array, tile_size>; -var up_inter_results: array, tile_size>; -var tile_A: array, double_tile_size_k_vec>; -var scale_A: array; - -fn loadSHMA(batch: u32, kidx_v: u32, col: u32) { - let k_offset = kidx_v + col; - if (k_offset >= uniforms.K16) { - return; - } - - tile_A[col] = a.getByOffset(batch * uniforms.K16 + k_offset); - if (col < scale_a_size_in_tile_a) { - scale_A[col] = scales_a.getByOffset(batch * (uniforms.K / 128) + kidx_v / 8 + col); - } -} - -$MAIN { - let batch = workgroup_id.z; - if (batch >= uniforms.batch_count) { - return; - } - - let b_global_base = workgroup_id.x * tile_size; - let local_col = local_idx % tile_size_k_vec; - let local_row = local_idx / tile_size_k_vec; - - if (local_idx < tile_size) { - for (var lane = 0u; lane < tile_size_k_vec; lane++) { - gate_inter_results[local_idx][lane] = output_element_t(0); - up_inter_results[local_idx][lane] = output_element_t(0); - } - } - workgroupBarrier(); - -#if single_scale_weights - let gate_scale_b = gate_scales_b.getByOffset(0); - let up_scale_b = up_scales_b.getByOffset(0); -#endif - - for (var kidx_v: u32 = 0u; kidx_v < uniforms.K32; kidx_v += tile_size_k_vec) { - if (local_idx < double_tile_size_k_vec) { - loadSHMA(batch, kidx_v * 2u, local_idx); - } - workgroupBarrier(); - - let own_a0 = tile_A[local_col * 2u]; - let own_a1 = tile_A[local_col * 2u + 1u]; - let own_scale_a = scale_A[local_col / 4u]; - let k_offset = kidx_v + local_col; - let block_idx = k_offset * 32u / uniforms.block_size; - - let b_global = b_global_base + local_row; - if (b_global < uniforms.N && k_offset < uniforms.K32) { -#if !single_scale_weights - let gate_scale_b = gate_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); - let up_scale_b = up_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); -#endif - let gate_b_value = gate_b.getByOffset(b_global * uniforms.K32 + k_offset); - let up_b_value = up_b.getByOffset(b_global * uniforms.K32 + k_offset); - let gate_b0 = DequantizedFrom4BitsTo8Bits(gate_b_value.xy, default_zero_point); - let gate_b1 = DequantizedFrom4BitsTo8Bits(gate_b_value.zw, default_zero_point); - let up_b0 = DequantizedFrom4BitsTo8Bits(up_b_value.xy, default_zero_point); - let up_b1 = DequantizedFrom4BitsTo8Bits(up_b_value.zw, default_zero_point); - let gate_scale = own_scale_a * gate_scale_b; - let up_scale = own_scale_a * up_scale_b; - gate_inter_results[local_row][local_col] += SDP8AI(own_a0, gate_b0, own_a1, gate_b1, gate_scale); - up_inter_results[local_row][local_col] += SDP8AI(own_a0, up_b0, own_a1, up_b1, up_scale); - } - workgroupBarrier(); - } - - if (local_idx < tile_size) { - var gate_output_value = output_element_t(0); - var up_output_value = output_element_t(0); - for (var lane = 0u; lane < tile_size_k_vec; lane++) { - gate_output_value += gate_inter_results[local_idx][lane]; - up_output_value += up_inter_results[local_idx][lane]; - } - - let b_global = b_global_base + local_idx; - if (b_global < uniforms.N) { -#if has_gate_bias - gate_output_value += gate_bias[b_global]; -#endif -#if has_up_bias - up_output_value += up_bias[b_global]; -#endif - let one = output_element_t(1.0); - let silu_value = gate_output_value * (one / (one + exp(-gate_output_value))); - output.setByOffset(batch * uniforms.N + b_global, silu_value * up_output_value); - } - } -} From a02cf12537a7a36e23bfecca52b8b008b9b12c54 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 1 May 2026 21:01:06 -0700 Subject: [PATCH 16/26] Cleanup: drop unused empty namespace + env_var_utils include in graph_transformer_utils --- onnxruntime/core/optimizer/graph_transformer_utils.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index adf3119be1244..36f3cdda81dc6 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -13,7 +13,6 @@ #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h" -#include "core/platform/env_var_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/platform/threadpool.h" @@ -102,10 +101,6 @@ namespace onnxruntime::optimizer_utils { -namespace { - -} // namespace - static void FilterTransformers(InlinedVector>& transformers, const InlinedHashSet& transformers_to_disable) { if (transformers_to_disable.empty()) return; From 906506343ce3f333e1761d47ede68b05b8220152 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 1 May 2026 22:35:30 -0700 Subject: [PATCH 17/26] Copilot comments --- .../webgpu/quantization/matmul_nbits.cc | 24 +-- .../quantization/matmul_nbits_common.cc | 141 ++++++++++++------ .../webgpu/quantization/matmul_nbits_common.h | 31 ++++ .../webgpu/quantization/matmul_nbits_mlp.cc | 24 +-- .../webgpu/quantization/matmul_nbits_qkv.cc | 26 ++-- .../core/optimizer/matmul_nbits_mlp_fusion.h | 12 +- .../core/optimizer/matmul_nbits_qkv_fusion.cc | 56 ++++++- .../optimizer/matmul_nbits_qkv_fusion_test.cc | 32 +++- 8 files changed, 252 insertions(+), 94 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 70c4ddfb19c04..e0a78aab1220b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -224,10 +224,11 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, #if !defined(__wasm__) // apple|intel - Experimental dawn support for subgroup matrix matmul. int32_t subgroup_matrix_config_index = -1; - if (WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(a, - K_op, - N_op, - block_size_op, + if (WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(M, + N, + K, + batch_count, + block_size, accuracy_level, nbits, context, @@ -241,10 +242,10 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, // On FP32 only GPUs and Qualcomm GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M. // DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values). - if (WouldApplyDP4AMatMulNBitsInCurrentDispatch(a, - K_op, - N_op, - block_size_op, + if (WouldApplyDP4AMatMulNBitsInCurrentDispatch(M, + N, + K, + block_size, accuracy_level, context, y, @@ -254,10 +255,9 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, // WideTileProgram // This program is optimized for Block32 prefill using Tile16x128. - const bool use_wide_tile_program = WouldApplyWideTileMatMulNBitsInCurrentDispatch(a, - K_op, - N_op, - block_size_op, + const bool use_wide_tile_program = WouldApplyWideTileMatMulNBitsInCurrentDispatch(M, + K, + block_size, nbits, has_weight_idx_indirect); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc index e615b90577f61..488194a60a31a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc @@ -68,29 +68,19 @@ bool HasDP4ADeviceSupport(int context_id) { ctx.AdapterInfo().vendor != std::string_view{"apple"}; } -bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, - int64_t K_op, - int64_t N_op, - int64_t block_size_op, - int64_t accuracy_level, - int64_t nbits, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y, - bool has_weight_idx_indirect, - int32_t* subgroup_matrix_config_index, - uint32_t override_M) { - TensorShape b_shape({N_op, K_op}); - MatMulComputeHelper helper; - if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { - return false; - } - - const uint32_t M = onnxruntime::narrow(helper.M()); - [[maybe_unused]] const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); +bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(uint32_t M, + [[maybe_unused]] uint32_t N, + [[maybe_unused]] uint32_t K, + [[maybe_unused]] uint32_t batch_count, + [[maybe_unused]] uint32_t block_size, + [[maybe_unused]] int64_t accuracy_level, + [[maybe_unused]] int64_t nbits, + [[maybe_unused]] onnxruntime::webgpu::ComputeContext& context, + [[maybe_unused]] Tensor* y, + [[maybe_unused]] bool has_weight_idx_indirect, + [[maybe_unused]] int32_t* subgroup_matrix_config_index, + [[maybe_unused]] uint32_t override_M) { [[maybe_unused]] const uint32_t dispatch_M = override_M > 0 ? override_M : M; - [[maybe_unused]] const uint32_t N = onnxruntime::narrow(helper.N()); - [[maybe_unused]] const uint32_t K = onnxruntime::narrow(helper.K()); - [[maybe_unused]] const uint32_t block_size = onnxruntime::narrow(block_size_op); #if !defined(__wasm__) int32_t local_subgroup_matrix_config_index = -1; @@ -108,9 +98,57 @@ bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, static_cast(nbits), y->DataType() == DataTypeImpl::GetType(), subgroup_matrix_config_index != nullptr ? *subgroup_matrix_config_index : local_subgroup_matrix_config_index); +#else + return false; #endif +} - return false; +bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + int64_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect, + int32_t* subgroup_matrix_config_index, + uint32_t override_M) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } + + return WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch( + onnxruntime::narrow(helper.M()), + onnxruntime::narrow(helper.N()), + onnxruntime::narrow(helper.K()), + onnxruntime::narrow(helper.OutputOffsets().size()), + onnxruntime::narrow(block_size_op), + accuracy_level, + nbits, + context, + y, + has_weight_idx_indirect, + subgroup_matrix_config_index, + override_M); +} + +bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + int64_t accuracy_level, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect) { + const uint32_t components_a = GetMaxComponents(K); + + return ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) || + y->DataType() == DataTypeImpl::GetType() || + context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && + CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a); } bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, @@ -127,16 +165,37 @@ bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, return false; } - const uint32_t M = onnxruntime::narrow(helper.M()); - const uint32_t N = onnxruntime::narrow(helper.N()); - const uint32_t K = onnxruntime::narrow(helper.K()); - const uint32_t block_size = onnxruntime::narrow(block_size_op); + return WouldApplyDP4AMatMulNBitsInCurrentDispatch( + onnxruntime::narrow(helper.M()), + onnxruntime::narrow(helper.N()), + onnxruntime::narrow(helper.K()), + onnxruntime::narrow(block_size_op), + accuracy_level, + context, + y, + has_weight_idx_indirect); +} + +bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(uint32_t M, + uint32_t K, + uint32_t block_size, + int64_t nbits, + bool has_weight_idx_indirect) { + if (has_weight_idx_indirect) { + return false; + } + const uint32_t components_a = GetMaxComponents(K); + const uint32_t block_size_per_col = block_size; + const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_b = GetMaxComponents(blob_size_in_words); - return ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) || - y->DataType() == DataTypeImpl::GetType() || - context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && - CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a); + return block_size == 32 && + components_a == 4 && + components_b == 4 && + nbits != 2 && + M >= kMinMForTileOptimization; } bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, @@ -155,20 +214,12 @@ bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, return false; } - const uint32_t M = onnxruntime::narrow(helper.M()); - const uint32_t K = onnxruntime::narrow(helper.K()); - const uint32_t block_size = onnxruntime::narrow(block_size_op); - const uint32_t components_a = GetMaxComponents(K); - const uint32_t block_size_per_col = block_size; - const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits); - const uint32_t blob_size_in_words = blob_size / 4; - const uint32_t components_b = GetMaxComponents(blob_size_in_words); - - return block_size == 32 && - components_a == 4 && - components_b == 4 && - nbits != 2 && - M >= kMinMForTileOptimization; + return WouldApplyWideTileMatMulNBitsInCurrentDispatch( + onnxruntime::narrow(helper.M()), + onnxruntime::narrow(helper.K()), + onnxruntime::narrow(block_size_op), + nbits, + has_weight_idx_indirect); } } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h index 883c6be02baa8..f3277e536ae62 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h @@ -48,6 +48,22 @@ bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, int32_t* subgroup_matrix_config_index = nullptr, uint32_t override_M = 0); +// Precomputed-dims overload for callers (e.g., ApplyMatMulNBits, MatMulNBitsMlp, +// MatMulNBitsQkv) that have already run MatMulComputeHelper and have M/N/K and +// batch_count in scope. Avoids re-running shape inference per dispatch decision. +bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(uint32_t M, + uint32_t N, + uint32_t K, + uint32_t batch_count, + uint32_t block_size, + int64_t accuracy_level, + int64_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect = false, + int32_t* subgroup_matrix_config_index = nullptr, + uint32_t override_M = 0); + bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, int64_t K_op, int64_t N_op, @@ -57,6 +73,15 @@ bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, Tensor* y, bool has_weight_idx_indirect = false); +bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + int64_t accuracy_level, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect = false); + bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, int64_t K_op, int64_t N_op, @@ -64,6 +89,12 @@ bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, int64_t nbits, bool has_weight_idx_indirect = false); +bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(uint32_t M, + uint32_t K, + uint32_t block_size, + int64_t nbits, + bool has_weight_idx_indirect = false); + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc index c89850567e7ad..4effbb82a8a26 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc @@ -445,27 +445,27 @@ Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont const bool has_norm_input = norm_scale != nullptr; const bool would_use_subgroup_unfused = - WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(a, - K_, - N_, - block_size_, + WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(M, + N, + K, + batch_count, + block_size, accuracy_level_, bits_, context, y); const bool would_use_dp4a_unfused = - WouldApplyDP4AMatMulNBitsInCurrentDispatch(a, - K_, - N_, - block_size_, + WouldApplyDP4AMatMulNBitsInCurrentDispatch(M, + N, + K, + block_size, accuracy_level_, context, y); const bool would_use_wide_tile_unfused = - WouldApplyWideTileMatMulNBitsInCurrentDispatch(a, - K_, - N_, - block_size_, + WouldApplyWideTileMatMulNBitsInCurrentDispatch(M, + K, + block_size, bits_); const bool can_use_decode_fast_path = diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc index b260489227243..aad6e57d19d22 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc @@ -392,31 +392,32 @@ Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont ORT_ENFORCE(norm_scale->Shape().Size() == K_, "norm_scale must have shape [K]."); + const uint32_t block_size = onnxruntime::narrow(block_size_); const bool would_use_subgroup_unfused = - WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(a, - K_, - Nq_, - block_size_, + WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(M, + Nq, + K, + batch_count, + block_size, accuracy_level_, bits_, context, q_output); const bool would_use_dp4a_unfused = !would_use_subgroup_unfused && - WouldApplyDP4AMatMulNBitsInCurrentDispatch(a, - K_, - Nq_, - block_size_, + WouldApplyDP4AMatMulNBitsInCurrentDispatch(M, + Nq, + K, + block_size, accuracy_level_, context, q_output); const bool would_use_wide_tile_unfused = !would_use_subgroup_unfused && !would_use_dp4a_unfused && - WouldApplyWideTileMatMulNBitsInCurrentDispatch(a, - K_, - Nq_, - block_size_, + WouldApplyWideTileMatMulNBitsInCurrentDispatch(M, + K, + block_size, bits_); if (would_use_subgroup_unfused || would_use_dp4a_unfused || would_use_wide_tile_unfused || M != 1) { @@ -464,7 +465,6 @@ Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont v_output); } - const uint32_t block_size = onnxruntime::narrow(block_size_); const uint32_t components_a = GetMaxComponents(K); const uint32_t block_size_per_col = block_size; const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; diff --git a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h index d201256b93f91..007d21027dca0 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h +++ b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h @@ -7,19 +7,23 @@ namespace onnxruntime { -// Fuses the SwiGLU MLP block (gate / up / down MatMulNBits projections around a +// Fuses the SwiGLU gated-activation subgraph (gate / up MatMulNBits projections around a // SimplifiedLayerNormalization anchor) into a single MatMulNBitsMlp contrib op: // // ... -> [Skip]SimplifiedLayerNormalization -+-> MatMulNBits (gate) -+-> Sigmoid -+ // | | | v // | | +----------> Mul (silu) -+ -// | +-> MatMulNBits (up) ---------------------------+--> Mul -> MatMulNBits (down) -> out +// | +-> MatMulNBits (up) ---------------------------+--> Mul -> out // +--> (optional) skip residual passthrough --> downstream consumers // // becomes // -// ... -> [Skip]SimplifiedLayerNormalization --> MatMulNBitsMlp(activation="silu") -+-> out -// +-> (optional) residual passthrough +// ... -> MatMulNBitsMlp(activation="silu") -+-> out +// +-> (optional) residual passthrough +// +// The downstream "down" projection (a third MatMulNBits that follows the gated-activation +// output) is intentionally NOT part of this fusion -- it remains a separate MatMulNBits node +// in the resulting graph. // // Only activation="silu" (i.e. x * Sigmoid(x)) is matched / emitted, and the fusion is restricted // to the WebGPU EP because MatMulNBitsMlp is a WebGPU-only contrib op. diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc index e29cd3e4e0030..840d1e99ba1bd 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc @@ -60,8 +60,40 @@ struct QkvNodes { const Node* v = nullptr; }; +bool IsGraphOutput(const Graph& graph, const Node& node, size_t index) { + if (!HasProducedOutput(node, index)) { + return false; + } + const auto& output_name = node.OutputDefs()[index]->Name(); + for (const auto* graph_output : graph.GetOutputs()) { + if (graph_output != nullptr && graph_output->Name() == output_name) { + return true; + } + } + return false; +} + +// Output 0 of the norm is consumed by the fused op, so it must not be a graph output. +// For SkipSimplifiedLayerNormalization the optional residual sum at output 3 is +// preserved by the fused MatMulNBitsQkv op, so it is allowed to remain a graph output. +// Outputs 1 and 2 (mean / inv_std_var) are not exposed by the fused op. +bool IsSupportedNormGraphOutputsForFusion(const Graph& graph, const Node& norm) { + if (IsGraphOutput(graph, norm, 0)) { + return false; + } + for (size_t i = 1; i < norm.OutputDefs().size(); ++i) { + if (!IsGraphOutput(graph, norm, i)) { + continue; + } + if (!(IsSupportedSkipSimplifiedLayerNormalization(norm) && i == 3)) { + return false; + } + } + return true; +} + std::optional GetQkvNodes(const Graph& graph, const Node& norm) { - if (!HasProducedOutput(norm, 0) || graph.NodeProducesGraphOutput(norm)) { + if (!HasProducedOutput(norm, 0) || !IsSupportedNormGraphOutputsForFusion(graph, norm)) { return std::nullopt; } @@ -235,6 +267,9 @@ Status MatMulNBitsQkvFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l const std::string v_name = qkv_nodes->v->Name(); const auto norm_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node); + const auto q_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*qkv_nodes->q); + const auto k_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*qkv_nodes->k); + const auto v_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*qkv_nodes->v); const auto q_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->q); const auto k_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->k); const auto v_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->v); @@ -272,6 +307,25 @@ Status MatMulNBitsQkvFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index); } + // Q/K/V weight + scale tensors are usually initializers, but if any of them is produced + // by an upstream node we must rewire that producer edge to the fused input slot. + auto add_input_edge_if_present = [&](const std::vector& edges, + int source_input_index, + int fused_input_index) { + for (const auto& input_edge : edges) { + if (input_edge.dst_arg_index == source_input_index) { + graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index); + } + } + }; + + add_input_edge_if_present(q_input_edges, 1, 3); // q_weight + add_input_edge_if_present(q_input_edges, 2, 4); // q_scales + add_input_edge_if_present(k_input_edges, 1, 5); // k_weight + add_input_edge_if_present(k_input_edges, 2, 6); // k_scales + add_input_edge_if_present(v_input_edges, 1, 7); // v_weight + add_input_edge_if_present(v_input_edges, 2, 8); // v_scales + for (const auto& output_edge : q_output_edges) { graph.AddEdge(fused_node.Index(), output_edge.dst_node, 0, output_edge.dst_arg_index); } diff --git a/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc index 853d641c755a1..3f84e630067e8 100644 --- a/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc +++ b/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc @@ -36,7 +36,7 @@ NodeAttributes MakeMatMulNBitsAttrs(int64_t k, int64_t n, int64_t block_size, in return attrs; } -Status CheckMatMulNBitsQkvFusedGraphImpl(const Graph& graph, bool expect_skip_sln_output) { +Status CheckMatMulNBitsQkvFusedGraphImpl(const Graph& graph, bool expect_skip_sln_output, bool expect_skip_input) { const auto op_to_count = CountOpsInGraph(graph); if (OpCount(op_to_count, "com.microsoft.MatMulNBitsQkv") != 1 || OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || @@ -54,6 +54,12 @@ Status CheckMatMulNBitsQkvFusedGraphImpl(const Graph& graph, bool expect_skip_sl ORT_RETURN_IF_NOT(node.InputDefs().size() == 9, "Fused node must expose the 9-input contract."); ORT_RETURN_IF_NOT(node.OutputDefs().size() == (expect_skip_sln_output ? 4u : 3u), "Fused node outputs did not match the expected simplified vs skip-simplified contract."); + // skip is at input index 1; for the SkipSimplifiedLayerNormalization-anchored pattern it + // must be wired to a real NodeArg, otherwise it must be the empty optional. + const auto* skip_def = node.InputDefs()[1]; + const bool skip_present = skip_def != nullptr && skip_def->Exists(); + ORT_RETURN_IF_NOT(skip_present == expect_skip_input, + "Fused node skip-input presence did not match the expected pattern variant."); } } @@ -61,15 +67,21 @@ Status CheckMatMulNBitsQkvFusedGraphImpl(const Graph& graph, bool expect_skip_sl } Status CheckMatMulNBitsQkvFusedGraph(Graph& graph) { - return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), false); + return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), + /*expect_skip_sln_output=*/false, + /*expect_skip_input=*/false); } Status CheckMatMulNBitsQkvSkipFusedGraph(Graph& graph) { - return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), false); + return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), + /*expect_skip_sln_output=*/false, + /*expect_skip_input=*/true); } Status CheckMatMulNBitsQkvSkipOutputPassthroughFusedGraph(Graph& graph) { - return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), true); + return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), + /*expect_skip_sln_output=*/true, + /*expect_skip_input=*/true); } void BuildMatMulNBitsQkvWebGpuPatternImpl(ModelTestBuilder& builder, bool with_skip_input, bool with_skip_output) { @@ -176,7 +188,9 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedWebGpuResults } auto check_transformed_graph = [](InferenceSessionWrapper& session) { - ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), false)); + ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), + /*expect_skip_sln_output=*/false, + /*expect_skip_input=*/false)); }; TransformerTester( @@ -212,7 +226,9 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuRes } auto check_transformed_graph = [](InferenceSessionWrapper& session) { - ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), false)); + ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), + /*expect_skip_sln_output=*/false, + /*expect_skip_input=*/true)); }; TransformerTester( @@ -253,7 +269,9 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuRes }; auto check_transformed_graph = [](InferenceSessionWrapper& session) { - ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), true)); + ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), + /*expect_skip_sln_output=*/true, + /*expect_skip_input=*/true)); }; TransformerTester( From 4ac9c8165a3f3e4899840ab231d279b223e4370f Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sat, 2 May 2026 14:53:40 -0700 Subject: [PATCH 18/26] Fixes --- docs/ContribOperators.md | 186 ++++++++++++++++++ .../webgpu/quantization/matmul_nbits_mlp.cc | 77 ++++++-- .../webgpu/quantization/matmul_nbits_qkv.cc | 163 ++++++++------- .../matmul_nbits_qkv.wgsl.template | 15 ++ .../optimizer/matmul_nbits_mlp_fusion_test.cc | 108 +++++++++- 5 files changed, 459 insertions(+), 90 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9aa44a1600ae6..f0d0dc1beaba5 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -57,6 +57,8 @@ Do not modify directly.* * com.microsoft.MatMulInteger16 * com.microsoft.MatMulIntegerToFloat * com.microsoft.MatMulNBits + * com.microsoft.MatMulNBitsMlp + * com.microsoft.MatMulNBitsQkv * com.microsoft.MaxpoolWithMask * com.microsoft.MoE * com.microsoft.MulInteger @@ -3189,6 +3191,190 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulNBitsMlp** + + MatMulNBitsMlp fuses two MatMulNBits projections that share the same input and computes + + gate = MatMulNBits(A, gate_weight) + gate_bias + up = MatMulNBits(A, up_weight) + up_bias + Y = activation(gate) * up + + It can also optionally fuse SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization before the + two projections: + + A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up + + A_norm = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up + + This operator is intended for decoder MLP patterns such as Qwen-style gate and up projections, but it remains + semantically valid for both prefill and decode because the output shape is the standard MatMul result shape + derived from the runtime shape of A and the shared attributes K and N. + + The operator contract includes a string attribute describing the fused gate activation. + + When fused from SkipSimplifiedLayerNormalization, the optional residual-sum output may also be materialized: + + A_norm, input_skip_bias_sum = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
Input feature dimension shared by both quantized weight matrices.
+
N : int (required)
+
Output feature dimension shared by both quantized weight matrices.
+
accuracy_level : int
+
The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.
+
activation : string (required)
+
Activation applied to the gate projection.
+
bits : int
+
Bit-width used to quantize both weight matrices (valid range: 2~8)
+
block_size : int (required)
+
Size of each quantization block along the K dimension. Must be a power of two and >= 16.
+
epsilon : float
+
Epsilon used by the optional fused (Skip)SimplifiedLayerNormalization. Defaults to 1e-5.
+
+ +#### Inputs (8 - 9) + +
+
A : T1
+
The shared input tensor.
+
skip (optional) : T1
+
Optional skip input used by SkipSimplifiedLayerNormalization.
+
norm_scale (optional) : T1
+
Optional RMSNorm scale with shape [K] used by SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization.
+
gate_B : T2
+
Packed uint8 tensor for the gate projection weights.
+
gate_scales : T1
+
Per-block scaling factors for the gate projection.
+
gate_bias (optional) : T1
+
Optional bias for the gate projection with shape [N].
+
up_B : T2
+
Packed uint8 tensor for the up projection weights.
+
up_scales : T1
+
Per-block scaling factors for the up projection.
+
up_bias (optional) : T1
+
Optional bias for the up projection with shape [N].
+
+ +#### Outputs (1 - 2) + +
+
Y : T1
+
The fused gated MLP output tensor.
+
input_skip_bias_sum (optional) : T1
+
Optional residual-sum output for SkipSimplifiedLayerNormalization.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + +### **com.microsoft.MatMulNBitsQkv** + + MatMulNBitsQkv fuses either SimplifiedLayerNormalization (RMSNorm) + or SkipSimplifiedLayerNormalization with three MatMulNBits projections that share the + same normalized activation. + + A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon) + Q = MatMulNBits(A_norm, q_weight) + K = MatMulNBits(A_norm, k_weight) + V = MatMulNBits(A_norm, v_weight) + + If skip is provided, the operator computes the SkipSimplifiedLayerNormalization variant + and may also return the input+skip residual sum as output 3. + + This operator is intended as a decode-oriented QKV fusion primitive. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
Input feature dimension shared by the normalized input and all projection weights.
+
Nkv : int (required)
+
Output feature dimension shared by the K and V projections.
+
Nq : int (required)
+
Output feature dimension of the Q projection.
+
accuracy_level : int
+
The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.
+
bits : int
+
Bit-width used to quantize all weight matrices (valid range: 2~8)
+
block_size : int (required)
+
Size of each quantization block along the K dimension. Must be a power of two and >= 16.
+
epsilon : float
+
Epsilon used by the simplified layer norm reduction.
+
+ +#### Inputs + +
+
A : T1
+
The shared input tensor.
+
skip (optional) : T1
+
Optional residual input for SkipSimplifiedLayerNormalization.
+
norm_scale : T1
+
Scale input for the simplified layer norm with shape [K].
+
q_B : T2
+
Packed uint8 tensor for the Q projection weights.
+
q_scales : T1
+
Per-block scaling factors for the Q projection.
+
k_B : T2
+
Packed uint8 tensor for the K projection weights.
+
k_scales : T1
+
Per-block scaling factors for the K projection.
+
v_B : T2
+
Packed uint8 tensor for the V projection weights.
+
v_scales : T1
+
Per-block scaling factors for the V projection.
+
+ +#### Outputs (3 - 4) + +
+
Q : T1
+
The Q projection output tensor.
+
K : T1
+
The K projection output tensor.
+
V : T1
+
The V projection output tensor.
+
input_skip_bias_sum (optional) : T1
+
Optional residual-sum output for SkipSimplifiedLayerNormalization.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + ### **com.microsoft.MaxpoolWithMask** For internal use. diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc index 4effbb82a8a26..9e495a0e2700c 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc @@ -3,6 +3,8 @@ #include "contrib_ops/webgpu/quantization/matmul_nbits_mlp.h" +#include + #include "contrib_ops/webgpu/quantization/matmul_nbits.h" #include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" #include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" @@ -482,6 +484,59 @@ Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont const bool has_gate_bias = gate_bias != nullptr; const bool has_up_bias = up_bias != nullptr; + + // The fully-fused MLP decode shader binds every weight/scale/bias plus the norm/skip + // tensors as storage buffers. Devices with a tight maxStorageBuffersPerShaderStage + // (notably macOS Metal at 10) cannot bind that many. For those devices we run the layer + // norm separately into a scratch tensor and then dispatch a no-norm variant of the + // decode program (which omits the norm_scale, skip, and skip-output bindings, dropping + // the storage-buffer count from up to 11 down to 8). + // + // Storage-buffer count: input_a + (skip?) + (norm_scale?) + 2 * (weight + scales) + // + output + (skip output?) + (gate_bias?) + (up_bias?) + const uint32_t required_storage_buffers = + 1u // input_a + + (has_skip_input ? 1u : 0u) // skip + + (has_norm_input ? 1u : 0u) // norm_scale + + 4u // gate/up weights + scales + + 1u // output + + (has_skip_output ? 1u : 0u) // skip output + + (has_gate_bias ? 1u : 0u) // gate bias + + (has_up_bias ? 1u : 0u); // up bias + const bool exceeds_storage_buffer_limit = + required_storage_buffers > context.DeviceLimits().maxStorageBuffersPerShaderStage; + + // Optionally pre-normalize a into a scratch tensor and drop the norm/skip bindings + // from the decode program. The user-visible residual passthrough (input_skip_bias_sum) + // is produced by the skip-norm op directly in this path. + std::optional normalized_a_storage; + const Tensor* decode_a = a; + if (exceeds_storage_buffer_limit && has_norm_input) { + normalized_a_storage.emplace(context.CreateGPUTensor(a->DataType(), a->Shape())); + if (has_skip_input) { + ORT_RETURN_IF_ERROR(ApplySkipSimplifiedLayerNorm(a, + skip, + norm_scale, + epsilon_, + context, + &*normalized_a_storage, + input_skip_bias_sum)); + } else { + ORT_RETURN_IF_ERROR(ApplySimplifiedLayerNorm(a, + norm_scale, + epsilon_, + context, + &*normalized_a_storage)); + } + decode_a = &*normalized_a_storage; + } + + // Decode-program-level norm/skip bindings: only used when the device has spare + // storage-buffer slots. Otherwise they are wired to the pre-normalized input above. + const bool decode_has_norm_input = has_norm_input && !exceeds_storage_buffer_limit; + const bool decode_has_skip_input = has_skip_input && !exceeds_storage_buffer_limit; + const bool decode_has_skip_output = has_skip_output && !exceeds_storage_buffer_limit; + uint32_t workgroup_size = 128; uint32_t tile_size = 8; uint32_t tile_size_k_vec = @@ -505,19 +560,19 @@ Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont MatMulNBitsMlpDecodeProgram program{tile_size, has_gate_bias, has_up_bias, - has_norm_input, - has_skip_input, - has_skip_output, + decode_has_norm_input, + decode_has_skip_input, + decode_has_skip_output, single_scale_weights, tile_size_k_vec, k_unroll_tiles}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize(num_N_tile, 1, batch_count); - program.AddInput({a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); - if (has_skip_input) { + program.AddInput({decode_a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + if (decode_has_skip_input) { program.AddInput({skip, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); } - if (has_norm_input) { + if (decode_has_norm_input) { program.AddInput({norm_scale, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); } program @@ -534,18 +589,18 @@ Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont {n_blocks_per_col}, {num_N_tile}, {batch_count}, - {has_skip_input ? onnxruntime::narrow(skip->Shape().Size()) : 0u}, + {decode_has_skip_input ? onnxruntime::narrow(skip->Shape().Size()) : 0u}, {epsilon_}}) .CacheHint(single_scale_weights, has_gate_bias, has_up_bias, - has_norm_input, - has_skip_input, - has_skip_output, + decode_has_norm_input, + decode_has_skip_input, + decode_has_skip_output, tile_size_k_vec, k_unroll_tiles, "decode_4bit"); - if (has_skip_output) { + if (decode_has_skip_output) { program.AddOutput({input_skip_bias_sum, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc index aad6e57d19d22..d4ecc69b44d30 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc @@ -191,6 +191,7 @@ class MatMulNBitsQkvDecodeProgram final bool single_scale_weights, uint32_t tile_size_k_vec, uint32_t k_unroll_tiles, + bool has_norm, bool has_skip_input, bool has_skip_output) : Program{"MatMulNBitsQkvDecode"}, @@ -198,13 +199,19 @@ class MatMulNBitsQkvDecodeProgram final single_scale_weights_(single_scale_weights), tile_size_k_vec_(tile_size_k_vec), k_unroll_tiles_(k_unroll_tiles), + has_norm_(has_norm), has_skip_input_(has_skip_input), - has_skip_output_(has_skip_output) {} + has_skip_output_(has_skip_output) { + // The no-norm variant runs against an already-normalized input tensor and therefore + // never owns the residual skip path nor the residual passthrough output. + ORT_ENFORCE(has_norm_ || (!has_skip_input_ && !has_skip_output_), + "MatMulNBitsQkvDecodeProgram: skip input/output require has_norm=true."); + } Status GenerateShaderCode(ShaderHelper& shader) const override { const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto* skip = has_skip_input_ ? &shader.AddInput("skip", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr; - const auto& norm_scale = shader.AddInput("norm_scale", ShaderUsage::UseValueTypeAlias); + const auto* norm_scale_ptr = has_norm_ ? &shader.AddInput("norm_scale", ShaderUsage::UseValueTypeAlias) : nullptr; const auto& q_b = shader.AddInput("q_b", ShaderUsage::UseValueTypeAlias); const auto& q_scales_b = shader.AddInput("q_scales_b"); const auto& k_b = shader.AddInput("k_b"); @@ -222,6 +229,7 @@ class MatMulNBitsQkvDecodeProgram final ShaderUsage::UseElementTypeAlias); const auto* input_skip_bias_sum = has_skip_output_ ? &shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr; const auto& skip_var = skip != nullptr ? *skip : a; + const auto& norm_scale_var = norm_scale_ptr != nullptr ? *norm_scale_ptr : a; const auto& input_skip_bias_sum_var = input_skip_bias_sum != nullptr ? *input_skip_bias_sum : q_output; const uint32_t components_a = a.NumComponents(); @@ -232,69 +240,12 @@ class MatMulNBitsQkvDecodeProgram final const uint32_t a_length_per_tile = tile_size_k / components_a; const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec; - if (skip != nullptr) { - if (input_skip_bias_sum != nullptr) { - return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv.wgsl.template", - WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), - WGSL_TEMPLATE_PARAMETER(component_a, components_a), - WGSL_TEMPLATE_PARAMETER(component_b, components_b), - WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), - WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_), - WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_), - WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_), - WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), - WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), - WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), - WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), - WGSL_TEMPLATE_VARIABLE(a, a), - WGSL_TEMPLATE_VARIABLE(input_skip_bias_sum, input_skip_bias_sum_var), - WGSL_TEMPLATE_VARIABLE(k_b, k_b), - WGSL_TEMPLATE_VARIABLE(k_output, k_output), - WGSL_TEMPLATE_VARIABLE(k_scales_b, k_scales_b), - WGSL_TEMPLATE_VARIABLE(norm_scale, norm_scale), - WGSL_TEMPLATE_VARIABLE(q_b, q_b), - WGSL_TEMPLATE_VARIABLE(q_output, q_output), - WGSL_TEMPLATE_VARIABLE(q_scales_b, q_scales_b), - WGSL_TEMPLATE_VARIABLE(skip, skip_var), - WGSL_TEMPLATE_VARIABLE(v_b, v_b), - WGSL_TEMPLATE_VARIABLE(v_output, v_output), - WGSL_TEMPLATE_VARIABLE(v_scales_b, v_scales_b)); - } - - return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv.wgsl.template", - WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), - WGSL_TEMPLATE_PARAMETER(component_a, components_a), - WGSL_TEMPLATE_PARAMETER(component_b, components_b), - WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), - WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_), - WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_), - WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_), - WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), - WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), - WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), - WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), - WGSL_TEMPLATE_VARIABLE(a, a), - WGSL_TEMPLATE_VARIABLE(input_skip_bias_sum, input_skip_bias_sum_var), - WGSL_TEMPLATE_VARIABLE(k_b, k_b), - WGSL_TEMPLATE_VARIABLE(k_output, k_output), - WGSL_TEMPLATE_VARIABLE(k_scales_b, k_scales_b), - WGSL_TEMPLATE_VARIABLE(norm_scale, norm_scale), - WGSL_TEMPLATE_VARIABLE(q_b, q_b), - WGSL_TEMPLATE_VARIABLE(q_output, q_output), - WGSL_TEMPLATE_VARIABLE(q_scales_b, q_scales_b), - WGSL_TEMPLATE_VARIABLE(skip, skip_var), - WGSL_TEMPLATE_VARIABLE(v_b, v_b), - WGSL_TEMPLATE_VARIABLE(v_output, v_output), - WGSL_TEMPLATE_VARIABLE(v_scales_b, v_scales_b)); - } - return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv.wgsl.template", WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), WGSL_TEMPLATE_PARAMETER(component_a, components_a), WGSL_TEMPLATE_PARAMETER(component_b, components_b), WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), + WGSL_TEMPLATE_PARAMETER(has_norm, has_norm_), WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_), WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_), WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_), @@ -308,7 +259,7 @@ class MatMulNBitsQkvDecodeProgram final WGSL_TEMPLATE_VARIABLE(k_b, k_b), WGSL_TEMPLATE_VARIABLE(k_output, k_output), WGSL_TEMPLATE_VARIABLE(k_scales_b, k_scales_b), - WGSL_TEMPLATE_VARIABLE(norm_scale, norm_scale), + WGSL_TEMPLATE_VARIABLE(norm_scale, norm_scale_var), WGSL_TEMPLATE_VARIABLE(q_b, q_b), WGSL_TEMPLATE_VARIABLE(q_output, q_output), WGSL_TEMPLATE_VARIABLE(q_scales_b, q_scales_b), @@ -336,6 +287,7 @@ class MatMulNBitsQkvDecodeProgram final bool single_scale_weights_; uint32_t tile_size_k_vec_; uint32_t k_unroll_tiles_; + bool has_norm_; bool has_skip_input_; bool has_skip_output_; }; @@ -420,7 +372,27 @@ Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont block_size, bits_); - if (would_use_subgroup_unfused || would_use_dp4a_unfused || would_use_wide_tile_unfused || M != 1) { + // The fused MatMulNBitsQkv shader binds every Q/K/V weight + scales tensor and the + // norm/skip tensors as storage buffers. Devices with a tight maxStorageBuffersPerShaderStage + // (notably macOS Metal at 10) cannot bind that many. For those devices we run the layer + // norm separately into a scratch tensor and then dispatch a no-norm variant of the fused + // QKV decode program (which omits the norm_scale, skip, and skip-output bindings, dropping + // the storage-buffer count from up to 13 down to 10). + // + // Storage-buffer count: input_a + (skip?) + norm_scale + 3 * (weight + scales) + // + q/k/v outputs + (skip output?) + const uint32_t required_storage_buffers = + 1u // input_a + + (skip != nullptr ? 1u : 0u) // skip + + 1u // norm_scale + + 6u // q/k/v weights + scales + + 3u // q/k/v outputs + + (input_skip_bias_sum != nullptr ? 1u : 0u); // skip output + const bool exceeds_storage_buffer_limit = + required_storage_buffers > context.DeviceLimits().maxStorageBuffersPerShaderStage; + + if (would_use_subgroup_unfused || would_use_dp4a_unfused || would_use_wide_tile_unfused || + M != 1) { if (skip != nullptr) { return ApplyUnfusedQKVSkipSimplifiedLayerNorm(a, skip, @@ -465,6 +437,32 @@ Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont v_output); } + // For the partial-fuse path, run [Skip]SimplifiedLayerNormalization into a scratch tensor + // first, then point the decode program at the normalized tensor with the norm/skip bindings + // turned off. The user-visible residual passthrough (input_skip_bias_sum) is produced by the + // skip-norm op directly, so the decode program never needs to write it itself. + std::optional normalized_a_storage; + const Tensor* decode_a = a; + if (exceeds_storage_buffer_limit) { + normalized_a_storage.emplace(context.CreateGPUTensor(a->DataType(), a->Shape())); + if (skip != nullptr) { + ORT_RETURN_IF_ERROR(ApplySkipSimplifiedLayerNorm(a, + skip, + norm_scale, + epsilon_, + context, + &*normalized_a_storage, + input_skip_bias_sum)); + } else { + ORT_RETURN_IF_ERROR(ApplySimplifiedLayerNorm(a, + norm_scale, + epsilon_, + context, + &*normalized_a_storage)); + } + decode_a = &*normalized_a_storage; + } + const uint32_t components_a = GetMaxComponents(K); const uint32_t block_size_per_col = block_size; const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; @@ -485,12 +483,21 @@ Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; const uint32_t k_tile_iterations = K / tile_size_k; + // Decode-program-level skip bindings: only used by the fully-fused path. The partial-fuse + // path has already merged the residual upstream and exposed the residual passthrough output + // through the layer-norm op, so the decode program runs with skip_input/skip_output disabled. + const bool decode_has_norm = !exceeds_storage_buffer_limit; + const bool decode_has_skip_input = !exceeds_storage_buffer_limit && skip != nullptr; std::optional input_skip_bias_sum_scratch; - Tensor* decode_input_skip_bias_sum = input_skip_bias_sum; - if (skip != nullptr && decode_input_skip_bias_sum == nullptr) { - input_skip_bias_sum_scratch.emplace(context.CreateGPUTensor(a->DataType(), a->Shape())); - decode_input_skip_bias_sum = &*input_skip_bias_sum_scratch; + Tensor* decode_input_skip_bias_sum = nullptr; + if (decode_has_skip_input) { + decode_input_skip_bias_sum = input_skip_bias_sum; + if (decode_input_skip_bias_sum == nullptr) { + input_skip_bias_sum_scratch.emplace(context.CreateGPUTensor(a->DataType(), a->Shape())); + decode_input_skip_bias_sum = &*input_skip_bias_sum_scratch; + } } + const bool decode_has_skip_output = decode_input_skip_bias_sum != nullptr; uint32_t k_unroll_tiles = 1; if ((K % tile_size_k) == 0) { @@ -507,18 +514,21 @@ Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont single_scale_weights, tile_size_k_vec, k_unroll_tiles, - skip != nullptr, - decode_input_skip_bias_sum != nullptr}; + decode_has_norm, + decode_has_skip_input, + decode_has_skip_output}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize(num_N_tile, 1, batch_count); program - .AddInput({a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); - if (skip != nullptr) { + .AddInput({decode_a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + if (decode_has_skip_input) { program.AddInput({skip, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); } + if (decode_has_norm) { + program.AddInput({norm_scale, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + } program - .AddInputs({{norm_scale, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, - {q_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + .AddInputs({{q_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, {q_scales, ProgramTensorMetadataDependency::TypeAndRank}, {k_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, {k_scales, ProgramTensorMetadataDependency::TypeAndRank}, @@ -536,7 +546,7 @@ Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont {n_blocks_per_col}, {num_N_tile}, {batch_count}, - {skip != nullptr ? onnxruntime::narrow(skip->Shape().Size()) : 0u}, + {decode_has_skip_input ? onnxruntime::narrow(skip->Shape().Size()) : 0u}, {epsilon_}}) .CacheHint(Nq, Nkv, @@ -545,10 +555,11 @@ Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont tile_size_k_vec, k_unroll_tiles, single_scale_weights, - skip != nullptr, - decode_input_skip_bias_sum != nullptr, + decode_has_norm, + decode_has_skip_input, + decode_has_skip_output, "decode_qkv_sln"); - if (decode_input_skip_bias_sum != nullptr) { + if (decode_has_skip_output) { program.AddOutput({decode_input_skip_bias_sum, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template index fe0dcd7fa7a1b..61b50ada36cfa 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template @@ -5,6 +5,7 @@ #param component_a #param component_b #param elements_in_value_b +#param has_norm #param has_skip_input #param has_skip_output #param k_unroll_tiles @@ -16,7 +17,9 @@ #use .getByOffset .setByOffset +#if has_norm var sum_squared_shared : array; +#endif var tile_A : array; var q_inter_results : array, tile_size>; var k_inter_results : array, tile_size>; @@ -71,6 +74,7 @@ fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) { let k_offset = kidx / component_a + col; let input_offset = batch * uniforms.K_of_a + k_offset; if (k_offset < uniforms.K_of_a) { +#if has_norm let merged_value = load_merged_input(input_offset); #if has_skip_output if (b_global_base == 0u) { @@ -78,6 +82,12 @@ fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) { } #endif tile_A[col] = merged_value * input_a_value_t(input_a_element_t(inv_std)) * norm_scale.getByOffset(k_offset); +#else + // Layer norm has already been applied to `a` upstream; load the pre-normalized value directly. + _ = b_global_base; + _ = inv_std; + tile_A[col] = a.getByOffset(input_offset); +#endif } else { tile_A[col] = input_a_value_t(0); } @@ -183,6 +193,7 @@ $MAIN { } } +#if has_norm var sum_squared_local = 0.0; for (var a_idx = local_idx; a_idx < uniforms.K_of_a; a_idx += workgroup_size_x) { let a_value = load_merged_input(batch * uniforms.K_of_a + a_idx); @@ -210,6 +221,10 @@ $MAIN { } let inv_std = inverseSqrt(sum_squared_shared[0] / f32(uniforms.K) + uniforms.epsilon); +#else + // Layer norm already applied upstream; inv_std is unused but kept in the loadSHMA signature. + let inv_std = 1.0; +#endif #if single_scale_weights let q_scale_b = q_output_element_t(q_scales_b.getByOffset(0)); diff --git a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc index cc59679aa4d63..fc4a04b3f0c7e 100644 --- a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc +++ b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc @@ -35,6 +35,11 @@ enum class SkipOutputKind { kGraphOutput, }; +enum class BiasKind { + kWithBias, + kNoBias, +}; + void SetWebGpuProvider(Node& node) { node.SetExecutionProviderType(kWebGpuExecutionProvider); } @@ -135,7 +140,8 @@ Status CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(const Graph& graph) { void BuildMatMulNBitsMlpWebGpuPatternImpl(ModelTestBuilder& builder, NormAnchorKind norm_anchor_kind, - SkipOutputKind skip_output_kind = SkipOutputKind::kNone) { + SkipOutputKind skip_output_kind = SkipOutputKind::kNone, + BiasKind bias_kind = BiasKind::kWithBias) { constexpr int64_t k = 32; constexpr int64_t n = 8; constexpr int64_t block_size = 32; @@ -158,10 +164,14 @@ void BuildMatMulNBitsMlpWebGpuPatternImpl(ModelTestBuilder& builder, NodeArg* gate_weight = builder.MakeInitializer({n, 1, blob_size}, uint8_t{0}, uint8_t{15}); NodeArg* gate_scale = builder.MakeInitializer({n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); - NodeArg* gate_bias = builder.MakeInitializer({n}, MLFloat16(0.0f), MLFloat16(0.0f)); + NodeArg* gate_bias = (bias_kind == BiasKind::kWithBias) + ? builder.MakeInitializer({n}, MLFloat16(0.0f), MLFloat16(0.0f)) + : optional_tensor; NodeArg* up_weight = builder.MakeInitializer({n, 1, blob_size}, uint8_t{0}, uint8_t{15}); NodeArg* up_scale = builder.MakeInitializer({n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); - NodeArg* up_bias = builder.MakeInitializer({n}, MLFloat16(0.0f), MLFloat16(0.0f)); + NodeArg* up_bias = (bias_kind == BiasKind::kWithBias) + ? builder.MakeInitializer({n}, MLFloat16(0.0f), MLFloat16(0.0f)) + : optional_tensor; NodeArg* normalized_input = builder.MakeIntermediate(std::vector{1, k}); NodeArg* gate_out = builder.MakeIntermediate(std::vector{1, n}); @@ -227,6 +237,21 @@ void BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPattern(ModelTestBuilder& bui BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified, SkipOutputKind::kGraphOutput); } +void BuildMatMulNBitsMlpSimplifiedWebGpuPatternNoBias(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSimplified, SkipOutputKind::kNone, + BiasKind::kNoBias); +} + +void BuildMatMulNBitsMlpSkipWebGpuPatternNoBias(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified, SkipOutputKind::kNone, + BiasKind::kNoBias); +} + +void BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPatternNoBias(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified, SkipOutputKind::kGraphOutput, + BiasKind::kNoBias); +} + } // namespace TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSimplifiedWebGpuPattern) { @@ -342,6 +367,83 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuRes std::move(webgpu_ep)); } +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedWebGpuResultsNoBias) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSimplifiedFusedGraph(session.GetGraph())); + }; + + TransformerTester( + BuildMatMulNBitsMlpSimplifiedWebGpuPatternNoBias, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + {}, + {}, + std::move(webgpu_ep)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResultsNoBias) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipFusedGraph(session.GetGraph())); + }; + + TransformerTester( + BuildMatMulNBitsMlpSkipWebGpuPatternNoBias, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + {}, + {}, + std::move(webgpu_ep)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResultsWithResidualOutputPassthroughNoBias) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto add_session_options = [](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableSpecifiedOptimizers, + "EliminateIdentity")); + }; + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(session.GetGraph())); + }; + + TransformerTester( + BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPatternNoBias, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + add_session_options, + {}, + std::move(webgpu_ep)); +} + #endif // !defined(DISABLE_CONTRIB_OPS) } // namespace test From 306fba37a0de3e4bc6276a27adf81d52a32f36c6 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sat, 2 May 2026 17:11:57 -0700 Subject: [PATCH 19/26] Fix --- onnxruntime/test/util/default_providers.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 7e6bc6ae06020..8184916bb479b 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -318,10 +318,14 @@ std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc) std::unique_ptr WebGpuExecutionProviderWithOptions(const ConfigOptions& config_options) { #if defined(USE_WEBGPU) #if defined(ORT_USE_EP_API_ADAPTERS) + // Return nullptr (rather than throwing) when the dynamic plugin EP is either uninitialized + // or initialized as a different EP. Tests interpret nullptr as "WebGPU EP unavailable" and + // skip themselves, which matches the behavior of the non-plugin code path below when + // USE_WEBGPU is undefined. auto ep_name = dynamic_plugin_ep_infra::GetEpName(); - ORT_ENFORCE(ep_name == kWebGpuExecutionProvider, - "Dynamic plugin EP is not the WebGPU EP. Expected \"", kWebGpuExecutionProvider, - "\", got \"", ep_name.value_or(""), "\""); + if (ep_name != kWebGpuExecutionProvider) { + return nullptr; + } return dynamic_plugin_ep_infra::MakeEp(nullptr, &config_options); #else return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); From 6c8c7a351885048662a10e41bcad383b6b18e52d Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sun, 3 May 2026 01:22:37 -0700 Subject: [PATCH 20/26] Use fresh WebGPU EP per session in fusion-vs-unfused tests The shared-EP path through TransformerTester triggers a SEH 0xC0000005 in CI when the EP outlives a per-session profiler whose pointer is still cached on the EP. A separate fix to the WebGPU EP's session_profiler_ lifetime is in flight; meanwhile, switch the 8 MatMulNBits MLP and QKV WebGPU fusion-vs- unfused tests to a small RunWebGpuFusionTransformerTest helper that creates a fresh execution provider per session via a factory lambda. Production code is unchanged. --- .../optimizer/matmul_nbits_mlp_fusion_test.cc | 57 ++++------ .../optimizer/matmul_nbits_qkv_fusion_test.cc | 29 ++--- .../test/optimizer/webgpu_fusion_test_util.h | 103 ++++++++++++++++++ 3 files changed, 135 insertions(+), 54 deletions(-) create mode 100644 onnxruntime/test/optimizer/webgpu_fusion_test_util.h diff --git a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc index fc4a04b3f0c7e..01ffbc1c71669 100644 --- a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc +++ b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc @@ -13,6 +13,7 @@ #include "test/unittest_util/framework_test_utils.h" #include "test/unittest_util/graph_transform_test_builder.h" #include "test/optimizer/graph_transform_test_fixture.h" +#include "test/optimizer/webgpu_fusion_test_util.h" #include "gtest/gtest.h" @@ -291,8 +292,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSkipWebGpuPatternWithR } TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedWebGpuResults) { - auto webgpu_ep = DefaultWebGpuExecutionProvider(); - if (!webgpu_ep) { + if (!DefaultWebGpuExecutionProvider()) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } @@ -300,7 +300,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedWeb ASSERT_STATUS_OK(CheckMatMulNBitsMlpSimplifiedFusedGraph(session.GetGraph())); }; - TransformerTester( + RunWebGpuFusionTransformerTest( BuildMatMulNBitsMlpSimplifiedWebGpuPattern, check_transformed_graph, TransformerLevel::Level1, @@ -309,14 +309,11 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedWeb 1e-3, 1e-3, std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - {}, - {}, - std::move(webgpu_ep)); + []() { return DefaultWebGpuExecutionProvider(); }); } TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResults) { - auto webgpu_ep = DefaultWebGpuExecutionProvider(); - if (!webgpu_ep) { + if (!DefaultWebGpuExecutionProvider()) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } @@ -324,7 +321,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuRes ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipFusedGraph(session.GetGraph())); }; - TransformerTester( + RunWebGpuFusionTransformerTest( BuildMatMulNBitsMlpSkipWebGpuPattern, check_transformed_graph, TransformerLevel::Level1, @@ -333,14 +330,11 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuRes 1e-3, 1e-3, std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - {}, - {}, - std::move(webgpu_ep)); + []() { return DefaultWebGpuExecutionProvider(); }); } TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResultsWithResidualOutputPassthrough) { - auto webgpu_ep = DefaultWebGpuExecutionProvider(); - if (!webgpu_ep) { + if (!DefaultWebGpuExecutionProvider()) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } @@ -353,7 +347,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuRes ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(session.GetGraph())); }; - TransformerTester( + RunWebGpuFusionTransformerTest( BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPattern, check_transformed_graph, TransformerLevel::Level1, @@ -362,14 +356,12 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuRes 1e-3, 1e-3, std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - add_session_options, - {}, - std::move(webgpu_ep)); + []() { return DefaultWebGpuExecutionProvider(); }, + add_session_options); } TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedWebGpuResultsNoBias) { - auto webgpu_ep = DefaultWebGpuExecutionProvider(); - if (!webgpu_ep) { + if (!DefaultWebGpuExecutionProvider()) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } @@ -377,7 +369,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedWeb ASSERT_STATUS_OK(CheckMatMulNBitsMlpSimplifiedFusedGraph(session.GetGraph())); }; - TransformerTester( + RunWebGpuFusionTransformerTest( BuildMatMulNBitsMlpSimplifiedWebGpuPatternNoBias, check_transformed_graph, TransformerLevel::Level1, @@ -386,14 +378,11 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedWeb 1e-3, 1e-3, std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - {}, - {}, - std::move(webgpu_ep)); + []() { return DefaultWebGpuExecutionProvider(); }); } TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResultsNoBias) { - auto webgpu_ep = DefaultWebGpuExecutionProvider(); - if (!webgpu_ep) { + if (!DefaultWebGpuExecutionProvider()) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } @@ -401,7 +390,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuRes ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipFusedGraph(session.GetGraph())); }; - TransformerTester( + RunWebGpuFusionTransformerTest( BuildMatMulNBitsMlpSkipWebGpuPatternNoBias, check_transformed_graph, TransformerLevel::Level1, @@ -410,14 +399,11 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuRes 1e-3, 1e-3, std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - {}, - {}, - std::move(webgpu_ep)); + []() { return DefaultWebGpuExecutionProvider(); }); } TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResultsWithResidualOutputPassthroughNoBias) { - auto webgpu_ep = DefaultWebGpuExecutionProvider(); - if (!webgpu_ep) { + if (!DefaultWebGpuExecutionProvider()) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } @@ -430,7 +416,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuRes ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(session.GetGraph())); }; - TransformerTester( + RunWebGpuFusionTransformerTest( BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPatternNoBias, check_transformed_graph, TransformerLevel::Level1, @@ -439,9 +425,8 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuRes 1e-3, 1e-3, std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - add_session_options, - {}, - std::move(webgpu_ep)); + []() { return DefaultWebGpuExecutionProvider(); }, + add_session_options); } #endif // !defined(DISABLE_CONTRIB_OPS) diff --git a/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc index 3f84e630067e8..754c23155bf47 100644 --- a/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc +++ b/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc @@ -12,6 +12,7 @@ #include "test/unittest_util/framework_test_utils.h" #include "test/unittest_util/graph_transform_test_builder.h" #include "test/optimizer/graph_transform_test_fixture.h" +#include "test/optimizer/webgpu_fusion_test_util.h" #include "gtest/gtest.h" @@ -182,8 +183,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionFusesWebGpuPattern) { } TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedWebGpuResults) { - auto webgpu_ep = DefaultWebGpuExecutionProvider(); - if (!webgpu_ep) { + if (!DefaultWebGpuExecutionProvider()) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } @@ -193,7 +193,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedWebGpuResults /*expect_skip_input=*/false)); }; - TransformerTester( + RunWebGpuFusionTransformerTest( BuildMatMulNBitsQkvWebGpuPattern, check_transformed_graph, TransformerLevel::Level1, @@ -202,9 +202,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedWebGpuResults 1e-3, 1e-3, std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - {}, - {}, - std::move(webgpu_ep)); + []() { return DefaultWebGpuExecutionProvider(); }); } TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionFusesSkipWebGpuPattern) { @@ -220,8 +218,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionFusesSkipWebGpuPattern) { } TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuResults) { - auto webgpu_ep = DefaultWebGpuExecutionProvider(); - if (!webgpu_ep) { + if (!DefaultWebGpuExecutionProvider()) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } @@ -231,7 +228,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuRes /*expect_skip_input=*/true)); }; - TransformerTester( + RunWebGpuFusionTransformerTest( BuildMatMulNBitsQkvSkipWebGpuPattern, check_transformed_graph, TransformerLevel::Level1, @@ -240,9 +237,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuRes 1e-3, 1e-3, std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - {}, - {}, - std::move(webgpu_ep)); + []() { return DefaultWebGpuExecutionProvider(); }); } TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionFusesSkipWebGpuPatternWithResidualOutputPassthrough) { @@ -258,8 +253,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionFusesSkipWebGpuPatternWithR } TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuResultsWithResidualOutputPassthrough) { - auto webgpu_ep = DefaultWebGpuExecutionProvider(); - if (!webgpu_ep) { + if (!DefaultWebGpuExecutionProvider()) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; } @@ -274,7 +268,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuRes /*expect_skip_input=*/true)); }; - TransformerTester( + RunWebGpuFusionTransformerTest( BuildMatMulNBitsQkvSkipOutputPassthroughWebGpuPattern, check_transformed_graph, TransformerLevel::Level1, @@ -283,9 +277,8 @@ TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuRes 1e-3, 1e-3, std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), - add_session_options, - {}, - std::move(webgpu_ep)); + []() { return DefaultWebGpuExecutionProvider(); }, + add_session_options); } #endif // !defined(DISABLE_CONTRIB_OPS) diff --git a/onnxruntime/test/optimizer/webgpu_fusion_test_util.h b/onnxruntime/test/optimizer/webgpu_fusion_test_util.h new file mode 100644 index 0000000000000..5cf0d63dc710b --- /dev/null +++ b/onnxruntime/test/optimizer/webgpu_fusion_test_util.h @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "core/framework/execution_provider.h" +#include "core/framework/session_options.h" +#include "core/graph/constants.h" +#include "core/graph/model.h" +#include "core/optimizer/graph_transformer.h" +#include "core/session/inference_session.h" +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +namespace onnxruntime { +namespace test { + +// Variant of TransformerTester for WebGPU fusion tests that creates a fresh execution provider +// per session via the provided factory, instead of sharing one EP across the baseline and target +// sessions. Sharing a single WebGPU EP across multiple InferenceSessions in series can leave the +// EP holding a dangling pointer to a destroyed session-level profiler; a separate fix to the EP +// addresses that, but using a fresh EP per session also avoids the issue and keeps the fusion PR +// independent of profiler-lifetime changes. +inline void RunWebGpuFusionTransformerTest( + const std::function& build_test_case, + const std::function& check_transformed_graph, + TransformerLevel baseline_level, + TransformerLevel target_level, + int opset_version, + double per_sample_tolerance, + double relative_per_sample_tolerance, + std::unique_ptr transformer, + const std::function()>& ep_factory, + const std::function& add_session_options = {}) { + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = opset_version; + domain_to_version[kMSDomain] = 1; + Model model("WebGpuFusionTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + ASSERT_TRUE(build_test_case); + build_test_case(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + auto run_model = [&](TransformerLevel level, std::vector& fetches, + std::unique_ptr level_transformer) { + SessionOptions session_options; + session_options.graph_optimization_level = level_transformer ? baseline_level : level; + if (add_session_options) { + add_session_options(session_options); + } + + InferenceSessionWrapper session{session_options, GetEnvironment()}; + auto ep = ep_factory(); + ASSERT_TRUE(ep != nullptr) << "ep_factory() returned nullptr"; + ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(ep))); + + ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast(model_data.size()))); + if (level_transformer) { + ASSERT_STATUS_OK(session.RegisterGraphTransformer(std::move(level_transformer), level)); + } + + ASSERT_STATUS_OK(session.Initialize()); + + RunOptions run_options; + ASSERT_STATUS_OK(session.Run(run_options, helper.feeds_, helper.output_names_, &fetches)); + + if (level == target_level && check_transformed_graph) { + check_transformed_graph(session); + } + }; + + std::vector baseline_fetches; + ASSERT_NO_FATAL_FAILURE(run_model(baseline_level, baseline_fetches, /*level_transformer=*/nullptr)); + + std::vector target_fetches; + ASSERT_NO_FATAL_FAILURE(run_model(target_level, target_fetches, std::move(transformer))); + + const size_t num_outputs = baseline_fetches.size(); + ASSERT_EQ(num_outputs, target_fetches.size()); + for (size_t i = 0; i < num_outputs; ++i) { + auto ret = CompareOrtValue(target_fetches[i], baseline_fetches[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } +} + +} // namespace test +} // namespace onnxruntime From a90a049ca64b053199a797692bcd590c5ef3850d Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sun, 10 May 2026 15:15:35 -0700 Subject: [PATCH 21/26] Remove unused file --- ...atmul_nbits_mlp_wide_tile_m1.wgsl.template | 127 ------------------ 1 file changed, 127 deletions(-) delete mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp_wide_tile_m1.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp_wide_tile_m1.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp_wide_tile_m1.wgsl.template deleted file mode 100644 index 47c292ba0ccf1..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp_wide_tile_m1.wgsl.template +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#param has_gate_bias -#param has_up_bias -#param outputs_per_thread - -#use .getByOffset .setByOffset - -const KAVecSizeForBlock32 = 8u; -const kTileN : u32 = workgroup_size_x * outputs_per_thread; -const kDefaultZeroPoint = output_element_t(8); - -var a_data_tile : array; - -fn load_a(batch : u32, col : u32) -> input_a_value_t { - if (batch < uniforms.batch_count && col < uniforms.K_of_a) { - let offset = batch * uniforms.K_of_a + col; - return a.getByOffset(offset); - } - - return input_a_value_t(); -} - -fn load_gate_b(row : u32, block_idx : u32) -> vec4 { - if (row < uniforms.N && block_idx < uniforms.K_of_b) { - let offset = row * uniforms.K_of_b + block_idx; - return gate_b.getByOffset(offset); - } - - return vec4(); -} - -fn load_up_b(row : u32, block_idx : u32) -> vec4 { - if (row < uniforms.N && block_idx < uniforms.K_of_b) { - let offset = row * uniforms.K_of_b + block_idx; - return up_b.getByOffset(offset); - } - - return vec4(); -} - -fn dequantize_u4_block(packed_data : u32, - scale : output_element_t) -> mat2x4 { - let lower : vec4 = unpack4xU8(packed_data & 0x0F0F0F0Fu); - let upper : vec4 = unpack4xU8((packed_data >> 4u) & 0x0F0F0F0Fu); - - let zero_matrix : mat2x4 = mat2x4( - kDefaultZeroPoint, kDefaultZeroPoint, kDefaultZeroPoint, kDefaultZeroPoint, - kDefaultZeroPoint, kDefaultZeroPoint, kDefaultZeroPoint, kDefaultZeroPoint); - - var dequantized_values : mat2x4 = mat2x4( - output_element_t(lower[0]), output_element_t(upper[0]), - output_element_t(lower[1]), output_element_t(upper[1]), - output_element_t(lower[2]), output_element_t(upper[2]), - output_element_t(lower[3]), output_element_t(upper[3])); - - dequantized_values = (dequantized_values - zero_matrix) * scale; - return dequantized_values; -} - -$MAIN { - let batch = workgroup_id.z; - let col_base = workgroup_id.x * kTileN + local_idx; - - var gate_results : array; - var up_results : array; - for (var output_idx = 0u; output_idx < outputs_per_thread; output_idx++) { - gate_results[output_idx] = output_element_t(0); - up_results[output_idx] = output_element_t(0); - } - - for (var block_idx = 0u; block_idx < uniforms.n_blocks_per_col; block_idx++) { - if (local_idx < KAVecSizeForBlock32) { - a_data_tile[local_idx] = load_a(batch, block_idx * KAVecSizeForBlock32 + local_idx); - } - workgroupBarrier(); - - for (var output_idx = 0u; output_idx < outputs_per_thread; output_idx++) { - let col = col_base + output_idx * workgroup_size_x; - if (col < uniforms.N) { - let gate_scale = gate_scales_b.getByOffset(col * uniforms.n_blocks_per_col + block_idx); - let up_scale = up_scales_b.getByOffset(col * uniforms.n_blocks_per_col + block_idx); - let gate_b_data = load_gate_b(col, block_idx); - let up_b_data = load_up_b(col, block_idx); - - for (var b_idx = 0u; b_idx < 4u; b_idx++) { - let gate_dequantized = dequantize_u4_block(gate_b_data[b_idx], gate_scale); - let up_dequantized = dequantize_u4_block(up_b_data[b_idx], up_scale); - let a_data0 = a_data_tile[b_idx * 2u]; - let a_data1 = a_data_tile[b_idx * 2u + 1u]; - - gate_results[output_idx] += dot(a_data0, gate_dequantized[0]) + - dot(a_data1, gate_dequantized[1]); - up_results[output_idx] += dot(a_data0, up_dequantized[0]) + - dot(a_data1, up_dequantized[1]); - } - } - } - - workgroupBarrier(); - } - - if (batch >= uniforms.batch_count) { - return; - } - - for (var output_idx = 0u; output_idx < outputs_per_thread; output_idx++) { - let col = col_base + output_idx * workgroup_size_x; - if (col >= uniforms.N) { - continue; - } - - var gate_result = gate_results[output_idx]; - var up_result = up_results[output_idx]; -#if has_gate_bias - gate_result += gate_bias[col]; -#endif -#if has_up_bias - up_result += up_bias[col]; -#endif - - let one = output_element_t(1.0); - let silu_value = gate_result * (one / (one + exp(-gate_result))); - output.setByOffset(batch * uniforms.N + col, silu_value * up_result); - } -} // MAIN From 007a78e7a61b7e0ef36cdae6a11ff0626f3a1114 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 11 May 2026 15:29:50 -0700 Subject: [PATCH 22/26] [WebGPU] Extract shared LayerNorm/SkipLayerNorm program runners Refactor in response to PR review feedback. The MatMulNBits MLP and QKV fusion kernels previously each carried their own private copies of the SimplifiedLayerNormalization and SkipSimplifiedLayerNormalization program launchers (`GetOverrideShape` + `ApplySimplifiedLayerNorm` + `ApplySkipSimplifiedLayerNorm`). Extract these into reusable helpers exposed by the existing LayerNorm / SkipLayerNorm kernel sources so fused kernels can drop the duplication. * core/providers/webgpu/nn/layer_norm.{h,cc}: - Expose `RunLayerNormProgram(...)` so other kernels can launch the simplified layer-norm program with consistent uniforms / shape overrides. * contrib_ops/webgpu/bert/skip_layer_norm.{h,cc}: - Expose `RunSkipLayerNormProgram(...)` mirroring the same shape for the SkipSimplifiedLayerNormalization variant. * contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc: - Adopt the shared helpers and delete the local copies. No behavior change; emitted WGSL and dispatch are byte-identical. --- .../webgpu/bert/skip_layer_norm.cc | 31 +++- .../contrib_ops/webgpu/bert/skip_layer_norm.h | 15 ++ .../webgpu/quantization/matmul_nbits_qkv.cc | 135 +++--------------- .../core/providers/webgpu/nn/layer_norm.cc | 26 +++- .../core/providers/webgpu/nn/layer_norm.h | 16 +++ 5 files changed, 100 insertions(+), 123 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc index 54df00fd58d92..04a38479fd6e0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -154,8 +154,26 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo auto* output = context.Output(0, x_shape); auto* input_skip_bias_sum = context.Output(3, x_shape); - int64_t data_size = x_shape.Size(); - if (data_size == 0) { + if (x_shape.Size() == 0) { + return Status::OK(); + } + + return RunSkipLayerNormProgram(context, x, skip, gamma, beta, bias, epsilon_, simplified, + output, input_skip_bias_sum); +} + +Status RunSkipLayerNormProgram(ComputeContext& context, + const Tensor* x, + const Tensor* skip, + const Tensor* gamma, + const Tensor* beta, + const Tensor* bias, + float epsilon, + bool simplified, + Tensor* output, + Tensor* input_skip_bias_sum) { + const auto& x_shape = x->Shape(); + if (x_shape.Size() == 0) { return Status::OK(); } @@ -165,17 +183,16 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(x_shape.NumDimensions() - 1)); const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1; - const auto skip_shape = skip->Shape(); - const uint32_t skip_size = onnxruntime::narrow(skip_shape.Size()); + const uint32_t skip_size = onnxruntime::narrow(skip->Shape().Size()); - SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim}; + SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim}; program .CacheHint(simplified, has_input_skip_bias_sum, split_hidden_dim) .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) - .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * data_size / hidden_size))) + .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * x_shape.Size() / hidden_size))) .AddUniformVariables({ {static_cast(components)}, }) @@ -183,7 +200,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo {static_cast(hidden_size)}, }) .AddUniformVariables({ - {static_cast(epsilon_)}, + {static_cast(epsilon)}, }) .AddUniformVariables({ {static_cast(skip_size)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h index bfaec1c3d0d79..0430074bb2ae0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h @@ -60,6 +60,21 @@ class SkipLayerNorm final : public WebGpuKernel { float epsilon_; }; +// Configures and dispatches a SkipLayerNormProgram. Centralizes program-setup logic +// (uniform variables, components, split_hidden_dim heuristic, workgroup sizing) so callers +// other than the SkipLayerNorm kernel (e.g. fused MatMulNBits ops) do not need to duplicate it. +// `beta`, `bias` and `input_skip_bias_sum` may be nullptr. +Status RunSkipLayerNormProgram(ComputeContext& context, + const Tensor* x, + const Tensor* skip, + const Tensor* gamma, + const Tensor* beta, + const Tensor* bias, + float epsilon, + bool simplified, + Tensor* output, + Tensor* input_skip_bias_sum); + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc index d4ecc69b44d30..4d46e03fba20c 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc @@ -23,105 +23,6 @@ namespace webgpu { namespace { -TensorShape GetOverrideShape(const TensorShape& shape, int components) { - return TensorShape{shape.Size() / components}; -} - -Status ApplySimplifiedLayerNorm(const Tensor* x, - const Tensor* scale, - float epsilon, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y) { - const auto& x_shape = x->Shape(); - if (x_shape.Size() == 0) { - return Status::OK(); - } - - const int64_t norm_size = x_shape[x_shape.NumDimensions() - 1]; - const uint32_t norm_count = onnxruntime::narrow(x_shape.Size() / norm_size); - const int components = GetMaxComponents(norm_size); - const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); - const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1; - - onnxruntime::webgpu::LayerNormProgram program{/*has_bias=*/false, - /*simplified=*/true, - /*has_mean_output=*/false, - /*has_inv_std_dev_output=*/false, - split_norm_dim}; - - program.CacheHint(components, true, split_norm_dim) - .AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x_shape, components), components}, - {scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}}) - .AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}}) - .AddUniformVariables({{static_cast(components)}, - {norm_count}, - {static_cast(norm_size)}, - {norm_size_vectorized}, - {epsilon}}); - - if (split_norm_dim) { - const uint32_t workgroup_size_x = 128; - const uint32_t dispatch_size_x = onnxruntime::narrow(norm_size / (workgroup_size_x * components)); - program.SetDispatchGroupSize(dispatch_size_x, 1, 1) - .SetWorkgroupSize(workgroup_size_x); - } else { - program.SetDispatchGroupSize(norm_count); - } - - return context.RunProgram(program); -} - -Status ApplySkipSimplifiedLayerNorm(const Tensor* x, - const Tensor* skip, - const Tensor* scale, - float epsilon, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y, - Tensor* input_skip_bias_sum) { - const auto& x_shape = x->Shape(); - if (x_shape.Size() == 0) { - return Status::OK(); - } - - const uint32_t hidden_size = onnxruntime::narrow(x_shape[x_shape.NumDimensions() - 1]); - const int components = GetMaxComponents(hidden_size); - const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(x_shape.NumDimensions() - 1)); - const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1; - const uint32_t skip_size = onnxruntime::narrow(skip->Shape().Size()); - - SkipLayerNormProgram program{/*hasBeta=*/false, - /*hasBias=*/false, - epsilon, - hidden_size, - input_skip_bias_sum != nullptr, - /*simplified=*/true, - split_hidden_dim}; - program - .CacheHint(/*simplified=*/true, input_skip_bias_sum != nullptr, split_hidden_dim) - .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) - .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) - .AddInputs({{scale, ProgramTensorMetadataDependency::Type, components}}) - .AddOutputs({{y, ProgramTensorMetadataDependency::None, components}}) - .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * x_shape.Size() / hidden_size))) - .AddUniformVariables({{static_cast(components)}}) - .AddUniformVariables({{hidden_size}}) - .AddUniformVariables({{epsilon}}) - .AddUniformVariables({{skip_size}}); - - if (split_hidden_dim) { - const uint32_t workgroup_size_x = 128; - const uint32_t dispatch_size_x = (input_skip_bias_sum != nullptr ? 2u : 1u) * hidden_size / (workgroup_size_x * components); - program.SetDispatchGroupSize(dispatch_size_x, 1, 1) - .SetWorkgroupSize(workgroup_size_x); - } - - if (input_skip_bias_sum != nullptr) { - program.AddOutputs({{input_skip_bias_sum, ProgramTensorMetadataDependency::None, components}}); - } - - return context.RunProgram(program); -} - Status ApplyUnfusedQKVSimplifiedLayerNorm(const Tensor* a, const Tensor* norm_scale, const Tensor* q_b, @@ -142,7 +43,12 @@ Status ApplyUnfusedQKVSimplifiedLayerNorm(const Tensor* a, Tensor* k_output, Tensor* v_output) { Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); - ORT_RETURN_IF_ERROR(ApplySimplifiedLayerNorm(a, norm_scale, epsilon, context, &normalized_a)); + const auto& a_shape = a->Shape(); + const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1]; + const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram( + context, a, norm_scale, /*bias=*/nullptr, epsilon, norm_count, norm_size, + /*simplified=*/true, &normalized_a, /*mean=*/nullptr, /*inv_std_dev=*/nullptr)); ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, q_b, q_scales, nullptr, nullptr, K, Nq, block_size, accuracy_level, bits, context, q_output)); ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, k_b, k_scales, nullptr, nullptr, @@ -174,7 +80,10 @@ Status ApplyUnfusedQKVSkipSimplifiedLayerNorm(const Tensor* a, Tensor* v_output, Tensor* input_skip_bias_sum) { Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); - ORT_RETURN_IF_ERROR(ApplySkipSimplifiedLayerNorm(a, skip, norm_scale, epsilon, context, &normalized_a, input_skip_bias_sum)); + ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale, + /*beta=*/nullptr, /*bias=*/nullptr, + epsilon, /*simplified=*/true, + &normalized_a, input_skip_bias_sum)); ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, q_b, q_scales, nullptr, nullptr, K, Nq, block_size, accuracy_level, bits, context, q_output)); ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, k_b, k_scales, nullptr, nullptr, @@ -446,19 +355,19 @@ Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont if (exceeds_storage_buffer_limit) { normalized_a_storage.emplace(context.CreateGPUTensor(a->DataType(), a->Shape())); if (skip != nullptr) { - ORT_RETURN_IF_ERROR(ApplySkipSimplifiedLayerNorm(a, - skip, - norm_scale, - epsilon_, - context, - &*normalized_a_storage, - input_skip_bias_sum)); + ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale, + /*beta=*/nullptr, /*bias=*/nullptr, + epsilon_, /*simplified=*/true, + &*normalized_a_storage, + input_skip_bias_sum)); } else { - ORT_RETURN_IF_ERROR(ApplySimplifiedLayerNorm(a, - norm_scale, - epsilon_, - context, - &*normalized_a_storage)); + const auto& a_shape = a->Shape(); + const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1]; + const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram( + context, a, norm_scale, /*bias=*/nullptr, epsilon_, norm_count, norm_size, + /*simplified=*/true, &*normalized_a_storage, /*mean=*/nullptr, + /*inv_std_dev=*/nullptr)); } decode_a = &*normalized_a_storage; } diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 7d4ae8c2197ff..3afedea30adaf 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -162,8 +162,6 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(axis)); const int64_t norm_size = x_shape.SizeFromDimension(axis); - const int components = GetMaxComponents(norm_size); - const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); const auto scale_size = scale->Shape().Size(); const auto bias_size = (bias) ? bias->Shape().Size() : 0; @@ -192,6 +190,28 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex return Status::OK(); } + return RunLayerNormProgram(context, x, scale, bias, epsilon_, norm_count, norm_size, + simplified, y, mean, inv_std_dev); +} + +Status RunLayerNormProgram(ComputeContext& context, + const Tensor* x, + const Tensor* scale, + const Tensor* bias, + float epsilon, + uint32_t norm_count, + int64_t norm_size, + bool simplified, + Tensor* y, + Tensor* mean, + Tensor* inv_std_dev) { + if (x->Shape().Size() == 0) { + return Status::OK(); + } + + const int components = GetMaxComponents(norm_size); + const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); + // Check if we should use split norm dimension optimization const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1; @@ -215,7 +235,7 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex {static_cast(norm_size_vectorized)}, }) .AddUniformVariables({ - {static_cast(epsilon_)}, + {static_cast(epsilon)}, }); if (split_norm_dim) { diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h index 112b152d37130..a6323dc7721d4 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.h +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h @@ -56,5 +56,21 @@ class LayerNorm final : public WebGpuKernel { int64_t stash_type_; }; +// Configures and dispatches a LayerNormProgram. Centralizes the program-setup logic +// (uniform variables, components, split_norm_dim heuristic, workgroup sizing) so callers +// other than the LayerNorm kernel (e.g. fused MatMulNBits ops) do not need to duplicate it. +// `bias`, `mean` and `inv_std_dev` may be nullptr. +Status RunLayerNormProgram(ComputeContext& context, + const Tensor* x, + const Tensor* scale, + const Tensor* bias, + float epsilon, + uint32_t norm_count, + int64_t norm_size, + bool simplified, + Tensor* y, + Tensor* mean, + Tensor* inv_std_dev); + } // namespace webgpu } // namespace onnxruntime From 37db5b87c61dde165ceab213e718a3e7271dd302 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 11 May 2026 15:30:14 -0700 Subject: [PATCH 23/26] [WebGPU] MatMulNBitsMlp: adopt shared norm helpers + activation enum Two coupled cleanups to the MatMulNBitsMlp kernel, kept together because they touch the same file: 1. Adopt the shared `RunLayerNormProgram` / `RunSkipLayerNormProgram` helpers introduced in the prior commit. Deletes the local copies of `GetOverrideShape`, `ApplySimplifiedLayerNorm`, and `ApplySkipSimplifiedLayerNorm`. No behavior change. 2. Introduce a small `MlpActivationKind` enum so the kernel can later gain GELU / GELU+Cast support (e.g. for Gemma-style MLPs) without reshaping the call paths or schema. Today the enum has a single value, `Silu = 0`, and the emitted WGSL is byte-identical to before. * matmul_nbits_mlp.h: - Add `MlpActivationKind` enum and `ParseMlpActivation()`. Kernel stores the parsed kind in `activation_kind_`. * matmul_nbits_mlp.cc: - Thread `MlpActivationKind` through `MatMulNBitsMlpProgram`, `MatMulNBitsMlpDecodeProgram`, and `ApplyUnfusedMlp`. Include the kind in each program's CacheHint. - Add `EmitGateActivationExpr()` so the inline kernel emits the activation expression via a single helper; today returns the SiLU expression. - While here, collapse the four identical `WGSL_TEMPLATE_APPLY` branches in `MatMulNBitsMlpDecodeProgram::GenerateShaderCode` into one call. Roughly 120 lines removed; emitted WGSL unchanged. * matmul_nbits_mlp.wgsl.template: - Add `#param activation_kind`. Wrap SiLU emission in `#if activation_kind == 0 ... #endif` and produce the activated value through a single `activated_value` binding so additional activations can be added with a new `#elif` branch. The schema already declares `activation` as a generic `STRING`, so no schema change is required. --- .../webgpu/quantization/matmul_nbits_mlp.cc | 272 +++++------------- .../webgpu/quantization/matmul_nbits_mlp.h | 23 +- .../matmul_nbits_mlp.wgsl.template | 14 +- 3 files changed, 99 insertions(+), 210 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc index 9e495a0e2700c..8dd454083c9b3 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc @@ -21,109 +21,32 @@ namespace onnxruntime { namespace contrib { namespace webgpu { -namespace { - -constexpr uint32_t kFusedDecodeFastPathBits = 4u; -constexpr uint32_t kFusedDecodeFastPathBlockSize = 32u; - -TensorShape GetOverrideShape(const TensorShape& shape, int components) { - return TensorShape{shape.Size() / components}; -} - -Status ApplySimplifiedLayerNorm(const Tensor* x, - const Tensor* scale, - float epsilon, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y) { - const auto& x_shape = x->Shape(); - if (x_shape.Size() == 0) { +Status ParseMlpActivation(std::string_view name, MlpActivationKind* out) { + if (name == "silu") { + *out = MlpActivationKind::Silu; return Status::OK(); } - - const int64_t norm_size = x_shape[x_shape.NumDimensions() - 1]; - const uint32_t norm_count = onnxruntime::narrow(x_shape.Size() / norm_size); - const int components = GetMaxComponents(norm_size); - const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); - const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1; - - onnxruntime::webgpu::LayerNormProgram program{/*has_bias=*/false, - /*simplified=*/true, - /*has_mean_output=*/false, - /*has_inv_std_dev_output=*/false, - split_norm_dim}; - - program.CacheHint(components, true, split_norm_dim) - .AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x_shape, components), components}, - {scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}}) - .AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}}) - .AddUniformVariables({{static_cast(components)}, - {norm_count}, - {static_cast(norm_size)}, - {norm_size_vectorized}, - {epsilon}}); - - if (split_norm_dim) { - const uint32_t workgroup_size_x = 128; - const uint32_t dispatch_size_x = onnxruntime::narrow(norm_size / (workgroup_size_x * components)); - program.SetDispatchGroupSize(dispatch_size_x, 1, 1) - .SetWorkgroupSize(workgroup_size_x); - } else { - program.SetDispatchGroupSize(norm_count); - } - - return context.RunProgram(program); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "MatMulNBitsMlp: activation '", name, "' is not supported."); } -Status ApplySkipSimplifiedLayerNorm(const Tensor* x, - const Tensor* skip, - const Tensor* scale, - float epsilon, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y, - Tensor* input_skip_bias_sum) { - const auto& x_shape = x->Shape(); - if (x_shape.Size() == 0) { - return Status::OK(); - } +namespace { - const uint32_t hidden_size = onnxruntime::narrow(x_shape[x_shape.NumDimensions() - 1]); - const int components = GetMaxComponents(hidden_size); - const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(x_shape.NumDimensions() - 1)); - const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1; - const uint32_t skip_size = onnxruntime::narrow(skip->Shape().Size()); - - SkipLayerNormProgram program{/*hasBeta=*/false, - /*hasBias=*/false, - epsilon, - hidden_size, - input_skip_bias_sum != nullptr, - /*simplified=*/true, - split_hidden_dim}; - program - .CacheHint(/*simplified=*/true, input_skip_bias_sum != nullptr, split_hidden_dim) - .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) - .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) - .AddInputs({{scale, ProgramTensorMetadataDependency::Type, components}}) - .AddOutputs({{y, ProgramTensorMetadataDependency::None, components}}) - .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * x_shape.Size() / hidden_size))) - .AddUniformVariables({{static_cast(components)}}) - .AddUniformVariables({{hidden_size}}) - .AddUniformVariables({{epsilon}}) - .AddUniformVariables({{skip_size}}); - - if (split_hidden_dim) { - const uint32_t workgroup_size_x = 128; - const uint32_t dispatch_size_x = (input_skip_bias_sum != nullptr ? 2u : 1u) * hidden_size / - (workgroup_size_x * components); - program.SetDispatchGroupSize(dispatch_size_x, 1, 1) - .SetWorkgroupSize(workgroup_size_x); - } +constexpr uint32_t kFusedDecodeFastPathBits = 4u; +constexpr uint32_t kFusedDecodeFastPathBlockSize = 32u; - if (input_skip_bias_sum != nullptr) { - program.AddOutputs({{input_skip_bias_sum, ProgramTensorMetadataDependency::None, components}}); +// Emits the WGSL expression that applies the gate activation. The result must use +// the variable names produced by the inline kernel (`gate_value`) and the shader +// template (`gate_output_value`), so callers pass the gate operand variable name. +// Adding a new activation here is the kernel-side counterpart to extending the +// MlpActivationKind enum. +std::string EmitGateActivationExpr(MlpActivationKind kind, std::string_view gate_var) { + switch (kind) { + case MlpActivationKind::Silu: + // SiLU(x) = x * sigmoid(x) + return std::string{gate_var} + " * (one / (one + exp(-" + std::string{gate_var} + ")))"; } - - return context.RunProgram(program); + ORT_THROW("MatMulNBitsMlp: unhandled MlpActivationKind ", static_cast(kind)); } class MatMulNBitsMlpDecodeProgram final : public Program { @@ -136,7 +59,8 @@ class MatMulNBitsMlpDecodeProgram final : public Program(activation_kind_)), WGSL_TEMPLATE_PARAMETER(component_a, components_a), WGSL_TEMPLATE_PARAMETER(component_b, components_b), WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), @@ -316,11 +159,15 @@ class MatMulNBitsMlpDecodeProgram final : public Program { public: - MatMulNBitsMlpProgram() : Program{"MatMulNBitsMlp"} {} + explicit MatMulNBitsMlpProgram(MlpActivationKind activation_kind) + : Program{"MatMulNBitsMlp"}, activation_kind_(activation_kind) { + CacheHint(static_cast(activation_kind_)); + } Status GenerateShaderCode(ShaderHelper& shader) const override { const auto& gate = shader.AddInput("gate", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); @@ -331,13 +178,16 @@ class MatMulNBitsMlpProgram final : public Program { << "let gate_value = " << gate.GetByOffset("global_idx") << ";\n" << "let up_value = " << up.GetByOffset("global_idx") << ";\n" << "let one = output_value_t(1.0);\n" - << "let silu_value = gate_value * (one / (one + exp(-gate_value)));\n" - << output.SetByOffset("global_idx", "silu_value * up_value"); + << "let activated_value = " << EmitGateActivationExpr(activation_kind_, "gate_value") << ";\n" + << output.SetByOffset("global_idx", "activated_value * up_value"); return Status::OK(); } WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + MlpActivationKind activation_kind_; }; Status ApplyUnfusedMlp(const Tensor* a, @@ -352,6 +202,7 @@ Status ApplyUnfusedMlp(const Tensor* a, int64_t block_size, int64_t accuracy_level, int64_t bits, + MlpActivationKind activation_kind, onnxruntime::webgpu::ComputeContext& context, Tensor* y) { MatMulComputeHelper helper; @@ -367,7 +218,7 @@ Status ApplyUnfusedMlp(const Tensor* a, const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); const uint32_t vec_size = (data_size + 3u) / 4u; - MatMulNBitsMlpProgram program; + MatMulNBitsMlpProgram program{activation_kind}; program .AddInputs({{&gate_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}, {&up_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}}) @@ -514,19 +365,20 @@ Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont if (exceeds_storage_buffer_limit && has_norm_input) { normalized_a_storage.emplace(context.CreateGPUTensor(a->DataType(), a->Shape())); if (has_skip_input) { - ORT_RETURN_IF_ERROR(ApplySkipSimplifiedLayerNorm(a, - skip, - norm_scale, - epsilon_, - context, - &*normalized_a_storage, - input_skip_bias_sum)); + ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale, + /*beta=*/nullptr, + /*bias=*/nullptr, + epsilon_, /*simplified=*/true, + &*normalized_a_storage, + input_skip_bias_sum)); } else { - ORT_RETURN_IF_ERROR(ApplySimplifiedLayerNorm(a, - norm_scale, - epsilon_, - context, - &*normalized_a_storage)); + const auto& a_shape = a->Shape(); + const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1]; + const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram( + context, a, norm_scale, /*bias=*/nullptr, epsilon_, norm_count, norm_size, + /*simplified=*/true, &*normalized_a_storage, /*mean=*/nullptr, + /*inv_std_dev=*/nullptr)); } decode_a = &*normalized_a_storage; } @@ -565,7 +417,8 @@ Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont decode_has_skip_output, single_scale_weights, tile_size_k_vec, - k_unroll_tiles}; + k_unroll_tiles, + activation_kind_}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize(num_N_tile, 1, batch_count); program.AddInput({decode_a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); @@ -599,6 +452,7 @@ Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont decode_has_skip_output, tile_size_k_vec, k_unroll_tiles, + static_cast(activation_kind_), "decode_4bit"); if (decode_has_skip_output) { program.AddOutput({input_skip_bias_sum, @@ -617,8 +471,10 @@ Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont if (skip != nullptr) { Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); - ORT_RETURN_IF_ERROR(ApplySkipSimplifiedLayerNorm(a, skip, norm_scale, epsilon_, - context, &normalized_a, input_skip_bias_sum)); + ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale, + /*beta=*/nullptr, /*bias=*/nullptr, + epsilon_, /*simplified=*/true, + &normalized_a, input_skip_bias_sum)); return ApplyUnfusedMlp(&normalized_a, gate_b, gate_scales, @@ -631,13 +487,19 @@ Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont block_size_, accuracy_level_, bits_, + activation_kind_, context, y); } if (norm_scale != nullptr) { Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); - ORT_RETURN_IF_ERROR(ApplySimplifiedLayerNorm(a, norm_scale, epsilon_, context, &normalized_a)); + const auto& a_shape = a->Shape(); + const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1]; + const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram( + context, a, norm_scale, /*bias=*/nullptr, epsilon_, norm_count, norm_size, + /*simplified=*/true, &normalized_a, /*mean=*/nullptr, /*inv_std_dev=*/nullptr)); return ApplyUnfusedMlp(&normalized_a, gate_b, gate_scales, @@ -650,6 +512,7 @@ Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont block_size_, accuracy_level_, bits_, + activation_kind_, context, y); } @@ -666,6 +529,7 @@ Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& cont block_size_, accuracy_level_, bits_, + activation_kind_, context, y); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h index c6ce500980ee9..002ebca4c0e54 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h @@ -3,8 +3,11 @@ #pragma once +#include #include +#include +#include "core/common/status.h" #include "core/providers/webgpu/webgpu_kernel.h" namespace onnxruntime { @@ -14,6 +17,17 @@ namespace webgpu { using namespace onnxruntime::webgpu; using onnxruntime::webgpu::ComputeContext; +// Gate activation applied between the gate and up MatMulNBits projections. +// Currently only SiLU is supported; future activations (e.g. GELU for Gemma-style +// gated MLPs) can be added here and threaded through the kernel and shader template. +enum class MlpActivationKind : uint32_t { + Silu = 0, +}; + +// Parses the `activation` attribute string into MlpActivationKind. Returns a non-OK +// Status for unsupported activations so the kernel rejects unknown values up front. +Status ParseMlpActivation(std::string_view name, MlpActivationKind* out); + class MatMulNBitsMlp final : public WebGpuKernel { public: explicit MatMulNBitsMlp(const OpKernelInfo& info) : WebGpuKernel(info) { @@ -23,12 +37,13 @@ class MatMulNBitsMlp final : public WebGpuKernel { bits_ = info.GetAttr("bits"); accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4); epsilon_ = info.GetAttrOrDefault("epsilon", 1e-5f); - ORT_ENFORCE(info.GetAttr("activation", &activation_).IsOK(), + std::string activation; + ORT_ENFORCE(info.GetAttr("activation", &activation).IsOK(), "MatMulNBitsMlp requires the 'activation' attribute."); + ORT_ENFORCE(ParseMlpActivation(activation, &activation_kind_).IsOK(), + "MatMulNBitsMlp: unsupported activation '", activation, "'."); ORT_ENFORCE(bits_ == 4 || bits_ == 8 || bits_ == 2, "Only 4b/8b/2b quantization is supported for MatMulNBitsMlp op."); - ORT_ENFORCE(activation_ == "silu", - "MatMulNBitsMlp currently only supports activation='silu'."); } Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; @@ -40,7 +55,7 @@ class MatMulNBitsMlp final : public WebGpuKernel { int64_t accuracy_level_; int64_t bits_; float epsilon_; - std::string activation_; + MlpActivationKind activation_kind_; }; } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template index a183962b230da..88d53e04177a8 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template @@ -16,6 +16,12 @@ #param has_gate_bias #param has_up_bias #param k_unroll_tiles +// Gate activation applied between gate and up projections. +// Mirrors MlpActivationKind in matmul_nbits_mlp.h: +// 0 = SiLU (only value currently supported) +// New activations are added by extending the enum, the EmitGateActivationExpr +// helper, the fusion matcher, and the activation block below. +#param activation_kind #use .getByOffset .setByOffset @@ -252,8 +258,12 @@ $MAIN { up_output_value += up_bias[b_global]; #endif let one = output_element_t(1.0); - let silu_value = gate_output_value * (one / (one + exp(-gate_output_value))); - output.setByOffset(output_idx, silu_value * up_output_value); +#if activation_kind == 0 + // SiLU(x) = x * sigmoid(x). New activations are added with additional + // `#elif activation_kind == N` blocks (must match MlpActivationKind). + let activated_value = gate_output_value * (one / (one + exp(-gate_output_value))); +#endif + output.setByOffset(output_idx, activated_value * up_output_value); } } } // MAIN From 2c1a2a368f01f89e9fdfe0785dbf653ccff26d42 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 11 May 2026 15:30:39 -0700 Subject: [PATCH 24/26] [WebGPU] MatMulNBitsMlpFusion: match fused-QuickGelu MLP shape After QuickGeluFusion is enabled for the WebGPU EP (upstream PR #28410), the SwiGLU gate subgraph `gate * Sigmoid(gate)` is collapsed into a single `com.microsoft::QuickGelu(gate, alpha=1.0)` node before MatMulNBitsMlpFusion runs. Without this change, the MLP fusion would silently stop firing for Qwen3 / Llama / Phi style WebGPU models. * core/optimizer/matmul_nbits_mlp_fusion.cc: - Recognize the QuickGelu-decomposed shape gate_matmul -> com.microsoft::QuickGelu(alpha=1.0) -> final_mul in addition to the existing Sigmoid+Mul shape. Validates QuickGelu's `alpha == 1.0` (SiLU-equivalent). - Factor common pair validation into `ValidateMatMulNBitsPair` and keep shape-specific checks in `IsFuseCandidateSilu` and `IsFuseCandidateQuickGelu`. - Restructure the main matching loop to dispatch on which shape was found and track the intermediate nodes to remove in a small vector, so the node-removal block stays uniform across shapes. * core/providers/webgpu/math/unary_elementwise_ops.h: - Fix the `QuickGeluImpl` WGSL shader for fp16 by wrapping `1.0`, `0.0`, and `uniforms.attr` in `x_element_t(...)` casts. Without this, pipeline creation fails on fp16 models with `Invalid ShaderModule "QuickGelu"`. Matches the fix in PR #28410 so the in-tree build can run QuickGelu on fp16 models immediately rather than waiting on that PR to land. * test/optimizer/matmul_nbits_mlp_fusion_test.cc: - Add unit coverage mirroring the existing SiLU tests. Introduces an `ActivationShape` enum and parameterizes the existing test-pattern builder. The graph-shape checkers now also assert zero `com.microsoft.QuickGelu` nodes after fusion. Adds four tests: * Fusion only (Simplified-LN anchor) * Fusion only (Skip-Simplified-LN anchor) * Fused vs unfused correctness on WebGPU (Simplified-LN) * Fused vs unfused correctness on WebGPU (Skip-Simplified-LN) Correctness tests use a slightly looser 5e-3 tolerance because the by-cases sigmoid in the QuickGelu shader produces marginally different fp16 rounding than the fused kernel's direct SiLU evaluation; the two are mathematically equivalent. --- .../core/optimizer/matmul_nbits_mlp_fusion.cc | 209 +++++++++++++----- .../webgpu/math/unary_elementwise_ops.h | 6 +- .../optimizer/matmul_nbits_mlp_fusion_test.cc | 131 ++++++++++- 3 files changed, 278 insertions(+), 68 deletions(-) diff --git a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc index 50ba18593089b..522e5f9e495cf 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc @@ -15,7 +15,12 @@ namespace { constexpr const char* kActivationAttrName = "activation"; // The transformer name is generic for future expansion, but the current fused -// pattern and emitted op only support gate activation = "silu". +// pattern and emitted op only support gate activation = "silu". To add another +// gate activation (e.g. GELU for Gemma-style MLPs), extend the pattern matcher +// below to recognize the new activation subgraph (or a unary node like `Gelu`), +// add the new value to `MlpActivationKind` in matmul_nbits_mlp.h, and update +// `EmitGateActivationExpr` plus the `#if activation_kind` block in the WGSL +// template. constexpr const char* kSupportedActivation = "silu"; bool HasInput(const Node& node, size_t index) { @@ -99,6 +104,16 @@ float GetFloatAttr(const Node& node, const char* name, float default_value) { return attr == nullptr ? default_value : attr->f(); } +bool IsSupportedQuickGelu(const Node& node) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "QuickGelu", {1}, kMSDomain)) { + return false; + } + // SiLU is equivalent to QuickGelu(x, alpha=1.0). Any other alpha is a valid + // QuickGelu activation but is not the SiLU function that the fused kernel + // implements, so we conservatively reject it here. + return GetFloatAttr(node, "alpha", 1.0f) == 1.0f; +} + bool HasSingleNonGraphConsumer(const Graph& graph, const Node& node) { return !graph.NodeProducesGraphOutput(node) && optimizer_utils::CheckOutputEdges(graph, node, 1); } @@ -137,23 +152,19 @@ const Node* GetNormProducer(const Graph& graph, return gate_input; } -bool IsFuseCandidate(const Graph& graph, - const Node& gate_matmul, - const Node& up_matmul, - const Node& sigmoid, - const Node& silu_mul, - const Node& final_mul) { - if (!IsMatMulNBitsWithoutZeroPointOrGroupIdx(gate_matmul) || !IsMatMulNBitsWithoutZeroPointOrGroupIdx(up_matmul) || - !IsSupportedSigmoid(sigmoid) || !IsSupportedMul(silu_mul) || !IsSupportedMul(final_mul)) { +bool ValidateMatMulNBitsPair(const Graph& graph, + const Node& gate_matmul, + const Node& up_matmul, + size_t expected_gate_fanout) { + if (!IsMatMulNBitsWithoutZeroPointOrGroupIdx(gate_matmul) || !IsMatMulNBitsWithoutZeroPointOrGroupIdx(up_matmul)) { return false; } - if (!HasSingleNonGraphConsumer(graph, up_matmul) || !HasSingleNonGraphConsumer(graph, sigmoid) || - !HasSingleNonGraphConsumer(graph, silu_mul)) { + if (!HasSingleNonGraphConsumer(graph, up_matmul)) { return false; } - if (graph.NodeProducesGraphOutput(gate_matmul) || gate_matmul.GetOutputEdgesCount() != 2) { + if (graph.NodeProducesGraphOutput(gate_matmul) || gate_matmul.GetOutputEdgesCount() != expected_gate_fanout) { return false; } @@ -162,6 +173,44 @@ bool IsFuseCandidate(const Graph& graph, return false; } + const int64_t gate_k = GetIntAttr(gate_matmul, "K", -1, true); + const int64_t up_k = GetIntAttr(up_matmul, "K", -1, true); + const int64_t gate_n = GetIntAttr(gate_matmul, "N", -1, true); + const int64_t up_n = GetIntAttr(up_matmul, "N", -1, true); + const int64_t gate_bits = GetIntAttr(gate_matmul, "bits", 4); + const int64_t up_bits = GetIntAttr(up_matmul, "bits", 4); + const int64_t gate_block_size = GetIntAttr(gate_matmul, "block_size", -1, true); + const int64_t up_block_size = GetIntAttr(up_matmul, "block_size", -1, true); + const int64_t gate_accuracy_level = GetIntAttr(gate_matmul, "accuracy_level", 0); + const int64_t up_accuracy_level = GetIntAttr(up_matmul, "accuracy_level", 0); + + return gate_k == up_k && gate_n == up_n && + gate_bits == up_bits && gate_bits == 4 && + gate_block_size == up_block_size && gate_block_size == 32 && + gate_accuracy_level == up_accuracy_level; +} + +// Validates the SiLU-decomposed activation shape: +// gate_matmul -> Sigmoid -+ +// gate_matmul ------------+-> silu_mul -> final_mul <- up_matmul +bool IsFuseCandidateSilu(const Graph& graph, + const Node& gate_matmul, + const Node& up_matmul, + const Node& sigmoid, + const Node& silu_mul, + const Node& final_mul) { + if (!IsSupportedSigmoid(sigmoid) || !IsSupportedMul(silu_mul) || !IsSupportedMul(final_mul)) { + return false; + } + + if (!HasSingleNonGraphConsumer(graph, sigmoid) || !HasSingleNonGraphConsumer(graph, silu_mul)) { + return false; + } + + if (!ValidateMatMulNBitsPair(graph, gate_matmul, up_matmul, /*expected_gate_fanout=*/2)) { + return false; + } + if (sigmoid.InputDefs()[0] != gate_matmul.OutputDefs()[0]) { return false; } @@ -176,25 +225,36 @@ bool IsFuseCandidate(const Graph& graph, const bool final_mul_matches = (final_mul.InputDefs()[0] == silu_mul.OutputDefs()[0] && final_mul.InputDefs()[1] == up_matmul.OutputDefs()[0]) || (final_mul.InputDefs()[1] == silu_mul.OutputDefs()[0] && final_mul.InputDefs()[0] == up_matmul.OutputDefs()[0]); - if (!final_mul_matches) { + return final_mul_matches; +} + +// Validates the fused-QuickGelu activation shape produced by QuickGeluFusion: +// gate_matmul -> QuickGelu(alpha=1.0) -> final_mul <- up_matmul +bool IsFuseCandidateQuickGelu(const Graph& graph, + const Node& gate_matmul, + const Node& up_matmul, + const Node& quick_gelu, + const Node& final_mul) { + if (!IsSupportedQuickGelu(quick_gelu) || !IsSupportedMul(final_mul)) { return false; } - const int64_t gate_k = GetIntAttr(gate_matmul, "K", -1, true); - const int64_t up_k = GetIntAttr(up_matmul, "K", -1, true); - const int64_t gate_n = GetIntAttr(gate_matmul, "N", -1, true); - const int64_t up_n = GetIntAttr(up_matmul, "N", -1, true); - const int64_t gate_bits = GetIntAttr(gate_matmul, "bits", 4); - const int64_t up_bits = GetIntAttr(up_matmul, "bits", 4); - const int64_t gate_block_size = GetIntAttr(gate_matmul, "block_size", -1, true); - const int64_t up_block_size = GetIntAttr(up_matmul, "block_size", -1, true); - const int64_t gate_accuracy_level = GetIntAttr(gate_matmul, "accuracy_level", 0); - const int64_t up_accuracy_level = GetIntAttr(up_matmul, "accuracy_level", 0); + if (!HasSingleNonGraphConsumer(graph, quick_gelu)) { + return false; + } - return gate_k == up_k && gate_n == up_n && - gate_bits == up_bits && gate_bits == 4 && - gate_block_size == up_block_size && gate_block_size == 32 && - gate_accuracy_level == up_accuracy_level; + if (!ValidateMatMulNBitsPair(graph, gate_matmul, up_matmul, /*expected_gate_fanout=*/1)) { + return false; + } + + if (quick_gelu.InputDefs()[0] != gate_matmul.OutputDefs()[0]) { + return false; + } + + const bool final_mul_matches = + (final_mul.InputDefs()[0] == quick_gelu.OutputDefs()[0] && final_mul.InputDefs()[1] == up_matmul.OutputDefs()[0]) || + (final_mul.InputDefs()[1] == quick_gelu.OutputDefs()[0] && final_mul.InputDefs()[0] == up_matmul.OutputDefs()[0]); + return final_mul_matches; } } // namespace @@ -228,43 +288,71 @@ Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l continue; } - const Node* silu_mul = nullptr; + const Node* activation_root = nullptr; const Node* up_matmul = nullptr; - if (IsSupportedMul(*input0) && IsMatMulNBitsWithoutZeroPointOrGroupIdx(*input1)) { - silu_mul = input0; + if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*input1) && + (IsSupportedMul(*input0) || IsSupportedQuickGelu(*input0))) { + activation_root = input0; up_matmul = input1; - } else if (IsSupportedMul(*input1) && IsMatMulNBitsWithoutZeroPointOrGroupIdx(*input0)) { - silu_mul = input1; + } else if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*input0) && + (IsSupportedMul(*input1) || IsSupportedQuickGelu(*input1))) { + activation_root = input1; up_matmul = input0; } else { continue; } - const Node* silu_input0 = GetInputNode(graph, *silu_mul, 0); - const Node* silu_input1 = GetInputNode(graph, *silu_mul, 1); - if (silu_input0 == nullptr || silu_input1 == nullptr) { - continue; - } - + // The gate-side subgraph between `gate_matmul` and the outer Mul `node` + // takes one of two shapes: + // 1) SiLU decomposed: gate -> Sigmoid -+ + // gate ------------+-> silu_mul -> node + // `activation_root` is the inner Mul (silu_mul); 2 intermediates. + // 2) Fused QuickGelu (post QuickGeluFusion): gate -> QuickGelu -> node + // `activation_root` is the QuickGelu node; 1 intermediate. const Node* gate_matmul = nullptr; - const Node* sigmoid = nullptr; - if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*silu_input0) && IsSupportedSigmoid(*silu_input1)) { - gate_matmul = silu_input0; - sigmoid = silu_input1; - } else if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*silu_input1) && IsSupportedSigmoid(*silu_input0)) { - gate_matmul = silu_input1; - sigmoid = silu_input0; + InlinedVector activation_intermediates; + const char* matched_shape = nullptr; + + if (IsSupportedQuickGelu(*activation_root)) { + const Node* qg_input = GetInputNode(graph, *activation_root, 0); + if (qg_input == nullptr || !IsMatMulNBitsWithoutZeroPointOrGroupIdx(*qg_input)) { + continue; + } + gate_matmul = qg_input; + if (!IsFuseCandidateQuickGelu(graph, *gate_matmul, *up_matmul, *activation_root, node)) { + continue; + } + activation_intermediates.push_back(activation_root); + matched_shape = "quick_gelu"; } else { - continue; - } + const Node* silu_input0 = GetInputNode(graph, *activation_root, 0); + const Node* silu_input1 = GetInputNode(graph, *activation_root, 1); + if (silu_input0 == nullptr || silu_input1 == nullptr) { + continue; + } - if (!IsFuseCandidate(graph, *gate_matmul, *up_matmul, *sigmoid, *silu_mul, node)) { - continue; + const Node* sigmoid = nullptr; + if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*silu_input0) && IsSupportedSigmoid(*silu_input1)) { + gate_matmul = silu_input0; + sigmoid = silu_input1; + } else if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*silu_input1) && IsSupportedSigmoid(*silu_input0)) { + gate_matmul = silu_input1; + sigmoid = silu_input0; + } else { + continue; + } + + if (!IsFuseCandidateSilu(graph, *gate_matmul, *up_matmul, *sigmoid, *activation_root, node)) { + continue; + } + activation_intermediates.push_back(sigmoid); + activation_intermediates.push_back(activation_root); + matched_shape = "silu"; } - LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: matched candidate output_mul='" << node.Name() + LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: matched candidate shape='" << matched_shape + << "' output_mul='" << node.Name() << "' gate='" << gate_matmul->Name() << "' up='" << up_matmul->Name() - << "' sigmoid='" << sigmoid->Name() << "' activation_mul='" << silu_mul->Name() << "' attrs={K=" << GetIntAttr(*gate_matmul, "K", -1, true) << ", N=" << GetIntAttr(*gate_matmul, "N", -1, true) << ", bits=" << GetIntAttr(*gate_matmul, "bits", 4) @@ -272,10 +360,17 @@ Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l << ", accuracy_level=" << GetIntAttr(*gate_matmul, "accuracy_level", 0) << "}"; + bool intermediates_on_supported_ep = true; + for (const Node* intermediate : activation_intermediates) { + const auto& ep = intermediate->GetExecutionProviderType(); + if (!ep.empty() && ep != kWebGpuExecutionProvider) { + intermediates_on_supported_ep = false; + break; + } + } if ((!gate_matmul->GetExecutionProviderType().empty() && gate_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) || (!up_matmul->GetExecutionProviderType().empty() && up_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) || - (!sigmoid->GetExecutionProviderType().empty() && sigmoid->GetExecutionProviderType() != kWebGpuExecutionProvider) || - (!silu_mul->GetExecutionProviderType().empty() && silu_mul->GetExecutionProviderType() != kWebGpuExecutionProvider)) { + !intermediates_on_supported_ep) { LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: skipping candidate due to non-WebGPU EP assignment."; continue; } @@ -334,10 +429,10 @@ Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l graph.RemoveNode(gate_matmul->Index()); graph_utils::RemoveNodeOutputEdges(graph, const_cast(*up_matmul)); graph.RemoveNode(up_matmul->Index()); - graph_utils::RemoveNodeOutputEdges(graph, const_cast(*sigmoid)); - graph.RemoveNode(sigmoid->Index()); - graph_utils::RemoveNodeOutputEdges(graph, const_cast(*silu_mul)); - graph.RemoveNode(silu_mul->Index()); + for (const Node* intermediate : activation_intermediates) { + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*intermediate)); + graph.RemoveNode(intermediate->Index()); + } graph_utils::RemoveNodeOutputEdges(graph, node); graph.RemoveNode(node.Index()); diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index 3285f1e6065bb..636f185eda422 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -136,9 +136,9 @@ fn elu_v(v: vec4) -> vec4 { constexpr const char QuickGeluImpl[] = R"( fn quick_gelu_v(a: vec4) -> vec4 { - let one = 1.0; - let zero = 0.0; - let alpha_vec = vec4(uniforms.attr); + let one = x_element_t(1.0); + let zero = x_element_t(0.0); + let alpha_vec = vec4(x_element_t(uniforms.attr)); let v = a * alpha_vec; var x1 : vec4; for (var i = 0; i < 4; i = i + 1) { diff --git a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc index 01ffbc1c71669..c62400c1e8346 100644 --- a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc +++ b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc @@ -41,6 +41,16 @@ enum class BiasKind { kNoBias, }; +// Selects the gate-activation subgraph emitted by the test builder. +// kSilu : gate -> Sigmoid -+ +// gate ------------+-> Mul -> final_mul +// kQuickGelu : gate -> com.microsoft::QuickGelu(alpha=1.0) -> final_mul +// (the shape QuickGeluFusion produces after PR #28410.) +enum class ActivationShape { + kSilu, + kQuickGelu, +}; + void SetWebGpuProvider(Node& node) { node.SetExecutionProviderType(kWebGpuExecutionProvider); } @@ -62,6 +72,7 @@ Status CheckMatMulNBitsMlpFusedGraphImpl(const Graph& graph, NormAnchorKind norm OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || OpCount(op_to_count, "com.microsoft.SkipSimplifiedLayerNormalization") != 0 || OpCount(op_to_count, "Sigmoid") != 0 || + OpCount(op_to_count, "com.microsoft.QuickGelu") != 0 || OpCount(op_to_count, "Mul") != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected operator counts after MatMulNBitsMlpFusion."); } @@ -105,6 +116,7 @@ Status CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(const Graph& graph) { OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || OpCount(op_to_count, "com.microsoft.SkipSimplifiedLayerNormalization") != 0 || OpCount(op_to_count, "Sigmoid") != 0 || + OpCount(op_to_count, "com.microsoft.QuickGelu") != 0 || OpCount(op_to_count, "Mul") != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected operator counts after MatMulNBitsMlpFusion with skip output passthrough."); @@ -142,7 +154,8 @@ Status CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(const Graph& graph) { void BuildMatMulNBitsMlpWebGpuPatternImpl(ModelTestBuilder& builder, NormAnchorKind norm_anchor_kind, SkipOutputKind skip_output_kind = SkipOutputKind::kNone, - BiasKind bias_kind = BiasKind::kWithBias) { + BiasKind bias_kind = BiasKind::kWithBias, + ActivationShape activation_shape = ActivationShape::kSilu) { constexpr int64_t k = 32; constexpr int64_t n = 8; constexpr int64_t block_size = 32; @@ -177,8 +190,7 @@ void BuildMatMulNBitsMlpWebGpuPatternImpl(ModelTestBuilder& builder, NodeArg* normalized_input = builder.MakeIntermediate(std::vector{1, k}); NodeArg* gate_out = builder.MakeIntermediate(std::vector{1, n}); NodeArg* up_out = builder.MakeIntermediate(std::vector{1, n}); - NodeArg* sigmoid_out = builder.MakeIntermediate(std::vector{1, n}); - NodeArg* silu_out = builder.MakeIntermediate(std::vector{1, n}); + NodeArg* activated_out = builder.MakeIntermediate(std::vector{1, n}); NodeArg* output = builder.MakeOutput(std::vector{1, n}); NodeAttributes matmul_attrs = MakeMatMulNBitsAttrs(k, n, block_size, bits, accuracy_level); @@ -212,17 +224,35 @@ void BuildMatMulNBitsMlpWebGpuPatternImpl(ModelTestBuilder& builder, {normalized_input, up_weight, up_scale, optional_tensor, optional_tensor, up_bias}, {up_out}, kMSDomain, &matmul_attrs); - Node& sigmoid = builder.AddNode("Sigmoid", {gate_out}, {sigmoid_out}); - Node& silu_mul = builder.AddNode("Mul", {gate_out, sigmoid_out}, {silu_out}); - Node& final_mul = builder.AddNode("Mul", {silu_out, up_out}, {output}); + + Node* sigmoid = nullptr; + Node* silu_mul = nullptr; + Node* quick_gelu = nullptr; + if (activation_shape == ActivationShape::kSilu) { + NodeArg* sigmoid_out = builder.MakeIntermediate(std::vector{1, n}); + sigmoid = &builder.AddNode("Sigmoid", {gate_out}, {sigmoid_out}); + silu_mul = &builder.AddNode("Mul", {gate_out, sigmoid_out}, {activated_out}); + } else { + NodeAttributes quick_gelu_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("alpha", 1.0f), quick_gelu_attrs); + quick_gelu = &builder.AddNode("QuickGelu", {gate_out}, {activated_out}, kMSDomain, &quick_gelu_attrs); + } + Node& final_mul = builder.AddNode("Mul", {activated_out, up_out}, {output}); if (norm != nullptr) { SetWebGpuProvider(*norm); } SetWebGpuProvider(gate_matmul); SetWebGpuProvider(up_matmul); - SetWebGpuProvider(sigmoid); - SetWebGpuProvider(silu_mul); + if (sigmoid != nullptr) { + SetWebGpuProvider(*sigmoid); + } + if (silu_mul != nullptr) { + SetWebGpuProvider(*silu_mul); + } + if (quick_gelu != nullptr) { + SetWebGpuProvider(*quick_gelu); + } SetWebGpuProvider(final_mul); } @@ -253,6 +283,16 @@ void BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPatternNoBias(ModelTestBuilde BiasKind::kNoBias); } +void BuildMatMulNBitsMlpSimplifiedQuickGeluWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSimplified, SkipOutputKind::kNone, + BiasKind::kWithBias, ActivationShape::kQuickGelu); +} + +void BuildMatMulNBitsMlpSkipQuickGeluWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified, SkipOutputKind::kNone, + BiasKind::kWithBias, ActivationShape::kQuickGelu); +} + } // namespace TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSimplifiedWebGpuPattern) { @@ -429,6 +469,81 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuRes add_session_options); } +// QuickGelu-shape tests: after PR #28410, QuickGeluFusion collapses the +// Sigmoid+Mul subgraph in SwiGLU MLPs into a single com.microsoft::QuickGelu +// node (with alpha=1.0). MatMulNBitsMlpFusion must still recognize this shape +// so the fused MLP kernel keeps firing on WebGPU models. +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSimplifiedQuickGeluWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsMlpSimplifiedQuickGeluWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsMlpSimplifiedFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSkipQuickGeluWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsMlpSkipQuickGeluWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsMlpSkipFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedQuickGeluWebGpuResults) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSimplifiedFusedGraph(session.GetGraph())); + }; + + // The unfused baseline runs the WebGPU `QuickGelu` kernel (a branchy + // sigmoid-by-cases implementation), while the fused kernel evaluates SiLU + // directly via `1 / (1 + exp(-x))`. The two decompositions are + // mathematically equivalent but produce slightly different fp16 rounding + // around the SiLU midpoint, so we use a marginally looser tolerance here. + RunWebGpuFusionTransformerTest( + BuildMatMulNBitsMlpSimplifiedQuickGeluWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 5e-3, + 5e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipQuickGeluWebGpuResults) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipFusedGraph(session.GetGraph())); + }; + + RunWebGpuFusionTransformerTest( + BuildMatMulNBitsMlpSkipQuickGeluWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 5e-3, + 5e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }); +} + #endif // !defined(DISABLE_CONTRIB_OPS) } // namespace test From 234bcf4473e60c970f9fe2530f5f7214225c79f2 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 11 May 2026 15:37:43 -0700 Subject: [PATCH 25/26] [WebGPU/JSEP] Enable QuickGeluFusion for WebGPU and JSEP EPs Widen `QuickGeluFusion`'s compatible-EP set from `cpu_acl_cuda_dml_eps` to `cpu_acl_cuda_dml_js_webgpu_eps` so the `x * Sigmoid(x)` SwiGLU gate pattern is folded into a single `com.microsoft::QuickGelu` node on WebGPU and JSEP models. Without this, the QuickGelu match branch added to `MatMulNBitsMlpFusion` in the prior commit is unreachable on real WebGPU models, and the `QuickGelu` fp16 shader fix in `unary_elementwise_ops.h` cannot be exercised end-to-end. Mirrors upstream PR #28410 (registers `QuickGeluFusion` for WebGPU/JSEP and fixes the `QuickGelu` fp16 shader). This commit is expected to be redundant once #28410 lands; rebase will drop it cleanly. --- onnxruntime/core/optimizer/graph_transformer_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 36f3cdda81dc6..a4311969a6d73 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -413,7 +413,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_js_webgpu_eps)); transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_js_webgpu_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_eps)); - transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_js_webgpu_eps)); // GeluApproximation has side effects which may change results. It needs to be manually enabled, // or alternatively the model can be updated offline using a model conversion script From eaa6635c7c2342a3b41bd10765ba6c082d8fed84 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 11 May 2026 19:51:55 -0700 Subject: [PATCH 26/26] Copilot comments --- .../matmul_nbits_mlp.wgsl.template | 11 ++++------ .../matmul_nbits_qkv.wgsl.template | 12 +++++----- .../core/optimizer/matmul_nbits_qkv_fusion.cc | 22 ++++++++++++++++++- .../optimizer/matmul_nbits_mlp_fusion_test.cc | 4 ++-- .../test/optimizer/webgpu_fusion_test_util.h | 1 + 5 files changed, 34 insertions(+), 16 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template index 88d53e04177a8..f64f0d38f24e2 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template @@ -65,7 +65,10 @@ fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) } fn compute_gate_up_sums(b_global: u32, kidx: u32, idx: u32, k_offset: u32) -> vec2 { -#if !single_scale_weights +#if single_scale_weights + let gate_scale_b = gate_scales_b.getByOffset(0); + let up_scale_b = up_scales_b.getByOffset(0); +#else let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; let gate_scale_b = gate_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); let up_scale_b = up_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); @@ -203,12 +206,6 @@ $MAIN { let inv_std = 1.0; #endif -#if single_scale_weights - let gate_scale_b = gate_scales_b.getByOffset(0); - let up_scale_b = up_scales_b.getByOffset(0); - let block_idx = 0u; -#endif - #if k_unroll_tiles == 1 for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) { process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template index 61b50ada36cfa..60f34e9ef2530 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template @@ -149,6 +149,12 @@ fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy } workgroupBarrier(); +#if single_scale_weights + let q_scale_b = q_output_element_t(q_scales_b.getByOffset(0)); + let k_scale_b = q_output_element_t(k_scales_b.getByOffset(0)); + let v_scale_b = q_output_element_t(v_scales_b.getByOffset(0)); +#endif + for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count) { let b_global = b_global_base + local_row_offset + idy; let k_offset = kidx / elements_in_value_b + idx; @@ -226,12 +232,6 @@ $MAIN { let inv_std = 1.0; #endif -#if single_scale_weights - let q_scale_b = q_output_element_t(q_scales_b.getByOffset(0)); - let k_scale_b = q_output_element_t(k_scales_b.getByOffset(0)); - let v_scale_b = q_output_element_t(v_scales_b.getByOffset(0)); -#endif - #if k_unroll_tiles == 1 for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) { process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc index 840d1e99ba1bd..05cd234dba577 100644 --- a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc @@ -4,6 +4,7 @@ #include "core/optimizer/matmul_nbits_qkv_fusion.h" #include +#include #include #include "core/graph/graph_utils.h" @@ -73,15 +74,34 @@ bool IsGraphOutput(const Graph& graph, const Node& node, size_t index) { return false; } +bool HasOutputConsumers(const Node& node, size_t index) { + if (!HasProducedOutput(node, index)) { + return false; + } + for (auto edge_it = node.OutputEdgesBegin(); edge_it != node.OutputEdgesEnd(); ++edge_it) { + if (static_cast(edge_it->GetSrcArgIndex()) == index) { + return true; + } + } + return false; +} + // Output 0 of the norm is consumed by the fused op, so it must not be a graph output. // For SkipSimplifiedLayerNormalization the optional residual sum at output 3 is // preserved by the fused MatMulNBitsQkv op, so it is allowed to remain a graph output. -// Outputs 1 and 2 (mean / inv_std_var) are not exposed by the fused op. +// Outputs 1 and 2 (mean / inv_std_var) are not exposed by the fused op and must not +// be graph outputs or feed any downstream nodes. bool IsSupportedNormGraphOutputsForFusion(const Graph& graph, const Node& norm) { if (IsGraphOutput(graph, norm, 0)) { return false; } for (size_t i = 1; i < norm.OutputDefs().size(); ++i) { + if (i == 1 || i == 2) { + if (IsGraphOutput(graph, norm, i) || HasOutputConsumers(norm, i)) { + return false; + } + continue; + } if (!IsGraphOutput(graph, norm, i)) { continue; } diff --git a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc index c62400c1e8346..038876b6c5777 100644 --- a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc +++ b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc @@ -517,8 +517,8 @@ TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedQui TransformerLevel::Level1, TransformerLevel::Level2, 21, - 5e-3, - 5e-3, + 1.5e-2, + 1.5e-2, std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), []() { return DefaultWebGpuExecutionProvider(); }); } diff --git a/onnxruntime/test/optimizer/webgpu_fusion_test_util.h b/onnxruntime/test/optimizer/webgpu_fusion_test_util.h index 5cf0d63dc710b..2fb3344bb9313 100644 --- a/onnxruntime/test/optimizer/webgpu_fusion_test_util.h +++ b/onnxruntime/test/optimizer/webgpu_fusion_test_util.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include