
Commit 14616c0

fix: fix timeout error in vllmwrapper (#143)
* fix: fix timeout error in vllmwrapper
* fix: delete useless prompt_logprobs=1
* fix: catch CancelledError
1 parent fc84539 commit 14616c0

2 files changed: 61 additions & 44 deletions

graphgen/bases/base_llm_wrapper.py (5 additions & 5 deletions)
@@ -26,11 +26,11 @@ def __init__(
         **kwargs: Any,
     ):
         self.system_prompt = system_prompt
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.repetition_penalty = repetition_penalty
-        self.top_p = top_p
-        self.top_k = top_k
+        self.temperature = float(temperature)
+        self.max_tokens = int(max_tokens)
+        self.repetition_penalty = float(repetition_penalty)
+        self.top_p = float(top_p)
+        self.top_k = int(top_k)
         self.tokenizer = tokenizer
 
         for k, v in kwargs.items():
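Why the casts matter: wrapper kwargs frequently arrive as strings (YAML configs, CLI flags, environment variables), and moving the coercion into the shared base class means every subclass, not just VLLMWrapper, gets numeric sampling parameters. A minimal sketch, assuming a hypothetical string-valued config dict:

    # Hypothetical config source: YAML or CLI values often arrive as strings.
    cfg = {"temperature": "0.6", "max_tokens": "512", "top_k": "5"}

    temperature = float(cfg["temperature"])  # "0.6" -> 0.6
    max_tokens = int(cfg["max_tokens"])      # "512" -> 512
    top_k = int(cfg["top_k"])                # "5"   -> 5

    # Without coercion, "0.6" > 0.0 raises TypeError, and numeric APIs
    # such as vLLM's SamplingParams would receive the wrong types.
    assert temperature > 0.0 and isinstance(top_k, int)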

graphgen/models/llm/local/vllm_wrapper.py (56 additions & 39 deletions)
@@ -1,6 +1,7 @@
 import math
 import uuid
 from typing import Any, List, Optional
+import asyncio
 
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
@@ -19,12 +20,9 @@ def __init__(
         temperature: float = 0.6,
         top_p: float = 1.0,
         top_k: int = 5,
+        timeout: float = 300,
         **kwargs: Any,
     ):
-        temperature = float(temperature)
-        top_p = float(top_p)
-        top_k = int(top_k)
-
         super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)
         try:
             from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
@@ -43,6 +41,7 @@ def __init__(
             disable_log_stats=False,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
+        self.timeout = float(timeout)
 
     @staticmethod
     def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
@@ -58,6 +57,12 @@ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
         lines.append(prompt)
         return "\n".join(lines)
 
+    async def _consume_generator(self, generator):
+        final_output = None
+        async for request_output in generator:
+            final_output = request_output
+        return final_output
+
     async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
@@ -72,14 +77,21 @@ async def generate_answer(
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
 
-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
+            )
+
+            if not final_output or not final_output.outputs:
+                return ""
 
-        if not final_output or not final_output.outputs:
-            return ""
+            result_text = final_output.outputs[0].text
+            return result_text
 
-        return final_output.outputs[0].text
+        except (Exception, asyncio.CancelledError):
+            await self.engine.abort(request_id)
+            raise
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
@@ -91,42 +103,47 @@ async def generate_topk_per_token(
             temperature=0,
             max_tokens=1,
             logprobs=self.top_k,
-            prompt_logprobs=1,
         )
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
 
-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if (
-            not final_output
-            or not final_output.outputs
-            or not final_output.outputs[0].logprobs
-        ):
-            return []
-
-        top_logprobs = final_output.outputs[0].logprobs[0]
-
-        candidate_tokens = []
-        for _, logprob_obj in top_logprobs.items():
-            tok_str = (
-                logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
             )
-            prob = float(math.exp(logprob_obj.logprob))
-            candidate_tokens.append(Token(tok_str, prob))
 
-        candidate_tokens.sort(key=lambda x: -x.prob)
+            if (
+                not final_output
+                or not final_output.outputs
+                or not final_output.outputs[0].logprobs
+            ):
+                return []
+
+            top_logprobs = final_output.outputs[0].logprobs[0]
+
+            candidate_tokens = []
+            for _, logprob_obj in top_logprobs.items():
+                tok_str = (
+                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+                )
+                prob = float(math.exp(logprob_obj.logprob))
+                candidate_tokens.append(Token(tok_str, prob))
+
+            candidate_tokens.sort(key=lambda x: -x.prob)
+
+            if candidate_tokens:
+                main_token = Token(
+                    text=candidate_tokens[0].text,
+                    prob=candidate_tokens[0].prob,
+                    top_candidates=candidate_tokens,
+                )
+                return [main_token]
+            return []
 
-        if candidate_tokens:
-            main_token = Token(
-                text=candidate_tokens[0].text,
-                prob=candidate_tokens[0].prob,
-                top_candidates=candidate_tokens,
-            )
-            return [main_token]
-        return []
+        except (Exception, asyncio.CancelledError):
+            await self.engine.abort(request_id)
+            raise
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
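
The control-flow change above deserves a note: asyncio.wait_for bounds how long the streaming generator may run, and catching (Exception, asyncio.CancelledError) is deliberate, because since Python 3.8 asyncio.CancelledError subclasses BaseException rather than Exception, so a plain "except Exception" would let a cancellation bypass the engine.abort() cleanup. A minimal runnable sketch of the same pattern, with a stub async generator standing in for AsyncLLMEngine.generate (fake_stream and generate_with_timeout are illustrative names, not part of the repo):

    import asyncio

    async def fake_stream():
        # Stand-in for AsyncLLMEngine.generate(): vLLM streams progressively
        # more complete outputs; only the last one matters here.
        text = ""
        for word in ["timeout", "errors", "handled"]:
            await asyncio.sleep(0.05)
            text += word + " "
            yield text

    async def consume(gen):
        # Same shape as _consume_generator above: drain the stream and
        # keep only the final output.
        final = None
        async for out in gen:
            final = out
        return final

    async def generate_with_timeout(timeout: float = 1.0) -> str:
        try:
            final = await asyncio.wait_for(consume(fake_stream()), timeout=timeout)
            return final or ""
        except (Exception, asyncio.CancelledError):
            # CancelledError is a BaseException since Python 3.8, so it has to
            # be listed explicitly; the real wrapper also calls
            # engine.abort(request_id) here before re-raising.
            raise

    print(asyncio.run(generate_with_timeout()))  # -> timeout errors handled

On timeout, wait_for cancels the consumer task and raises asyncio.TimeoutError, which the except clause catches; aborting the request on the engine side then stops vLLM from processing it further instead of leaving it running in the background.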
