From 03224283b7778f843575306351b9aa115587242c Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:49:12 +0800 Subject: [PATCH] fix mtp cuda graph init. --- lightllm/common/basemodel/cuda_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index dd29c9a833..5e8301015c 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -33,7 +33,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192): graph_split_batch_size = self.args.graph_split_batch_size * (self.mtp_step + 1) graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1) - batch_sizes = [i * (self.mtp_step + 1) for i in range(1, graph_split_batch_size + 1)] + batch_sizes = [i * (self.mtp_step + 1) for i in range(1, self.args.graph_split_batch_size + 1)] for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size): batch_sizes.append(_batch_size)