diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index dd29c9a833..5e8301015c 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -33,7 +33,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192): graph_split_batch_size = self.args.graph_split_batch_size * (self.mtp_step + 1) graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1) - batch_sizes = [i * (self.mtp_step + 1) for i in range(1, graph_split_batch_size + 1)] + batch_sizes = [i * (self.mtp_step + 1) for i in range(1, self.args.graph_split_batch_size + 1)] for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size): batch_sizes.append(_batch_size)