Skip to content

Commit 1a2b062

Browse files
committed
feat: auto-calculate max_new_tokens to align with vLLM behavior
When max_new_tokens is not specified (None or -1), automatically calculate it as max_req_total_len - prompt_tokens. This aligns with vLLM's behavior, where max_tokens defaults to the remaining context length. Changes: (1) sampling_params.py: default max_new_tokens changed from 16384 to -1; (2) py_sampling_params.py: default max_new_tokens changed from 16384 to None; (3) manager.py: add auto-calculation logic in _check_and_repair_length.
1 parent 391d2ea commit 1a2b062

3 files changed

Lines changed: 23 additions & 6 deletions

File tree

lightllm/server/core/objs/py_sampling_params.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
top_k: int = None, # -1 is for all
3939
ignore_eos: bool = False,
4040
image_max_patch_num: int = -1,
41-
max_new_tokens: int = 16384,
41+
max_new_tokens: int = None, # If None, auto-calculate as max_req_total_len - input_len
4242
min_new_tokens: int = 1,
4343
stop_sequences: Optional[Union[str, List[str], List[List[int]]]] = None, # 停止句子条件
4444
skip_special_tokens: bool = True, # whether to skip special tokens when decoding
@@ -141,11 +141,11 @@ def verify(self):
141141
raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}")
142142
if self.top_k < -1 or self.top_k == 0:
143143
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
144-
if self.max_new_tokens < 1:
144+
if self.max_new_tokens is not None and self.max_new_tokens < 1:
145145
raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
146146
if self.min_new_tokens < 1:
147147
raise ValueError(f"min_new_tokens must be at least 1, got {self.min_new_tokens}.")
148-
if self.min_new_tokens > self.max_new_tokens:
148+
if self.max_new_tokens is not None and self.min_new_tokens > self.max_new_tokens:
149149
raise ValueError(
150150
f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
151151
)

lightllm/server/core/objs/sampling_params.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,9 @@ def init(self, tokenizer, **kwargs):
345345
self.top_k = kwargs.get("top_k", SamplingParams._top_k)
346346
self.ignore_eos = kwargs.get("ignore_eos", False)
347347
self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
348-
self.max_new_tokens = kwargs.get("max_new_tokens", 16384)
348+
self.max_new_tokens = kwargs.get(
349+
"max_new_tokens", -1
350+
) # -1 means auto-calculate as max_req_total_len - input_len
349351
self.min_new_tokens = kwargs.get("min_new_tokens", 1)
350352
self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
351353
self.group_request_id = kwargs.get("group_request_id", -1)
@@ -439,11 +441,11 @@ def verify(self):
439441
raise ValueError(f"top_p must be in (0.0, 1.0], got {self.top_p}")
440442
if self.top_k < -1 or self.top_k == 0:
441443
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
442-
if self.max_new_tokens < 1:
444+
if self.max_new_tokens != -1 and self.max_new_tokens < 1:
443445
raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
444446
if self.min_new_tokens < 1:
445447
raise ValueError(f"min_new_tokens must be at least 1 , got {self.min_new_tokens}.")
446-
if self.min_new_tokens > self.max_new_tokens:
448+
if self.max_new_tokens != -1 and self.min_new_tokens > self.max_new_tokens:
447449
raise ValueError(
448450
f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
449451
)

lightllm/server/httpserver/manager.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,21 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
477477
if not prompt_ids:
478478
raise ValueError("prompt_ids is empty")
479479
prompt_tokens = len(prompt_ids)
480+
481+
# If max_new_tokens is None or -1, auto-calculate based on model context length (align with vLLM behavior)
482+
# -1 is used as sentinel for ctypes-based SamplingParams, None for pure Python SamplingParams
483+
if sampling_params.max_new_tokens is None or sampling_params.max_new_tokens == -1:
484+
sampling_params.max_new_tokens = self.max_req_total_len - prompt_tokens
485+
if sampling_params.max_new_tokens < 1:
486+
raise ValueError(
487+
f"the input prompt token len {prompt_tokens} >= max_req_total_len {self.max_req_total_len}, "
488+
f"no space for output tokens"
489+
)
490+
logger.debug(
491+
f"max_new_tokens is unset, auto-calculate to {sampling_params.max_new_tokens} "
492+
f"(max_req_total_len {self.max_req_total_len} - prompt_tokens {prompt_tokens})"
493+
)
494+
480495
if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len:
481496
# use long_truncation_mode to truncate long input len req.
482497
if self.args.long_truncation_mode is None:

0 commit comments

Comments (0)