11import math
22import uuid
33from typing import Any , List , Optional
4+ import asyncio
45
56from graphgen .bases .base_llm_wrapper import BaseLLMWrapper
67from graphgen .bases .datatypes import Token
@@ -19,12 +20,9 @@ def __init__(
1920 temperature : float = 0.6 ,
2021 top_p : float = 1.0 ,
2122 top_k : int = 5 ,
23+ timeout : float = 300 ,
2224 ** kwargs : Any ,
2325 ):
24- temperature = float (temperature )
25- top_p = float (top_p )
26- top_k = int (top_k )
27-
2826 super ().__init__ (temperature = temperature , top_p = top_p , top_k = top_k , ** kwargs )
2927 try :
3028 from vllm import AsyncEngineArgs , AsyncLLMEngine , SamplingParams
@@ -43,6 +41,7 @@ def __init__(
4341 disable_log_stats = False ,
4442 )
4543 self .engine = AsyncLLMEngine .from_engine_args (engine_args )
44+ self .timeout = float (timeout )
4645
4746 @staticmethod
4847 def _build_inputs (prompt : str , history : Optional [List [str ]] = None ) -> str :
@@ -58,6 +57,12 @@ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
5857 lines .append (prompt )
5958 return "\n " .join (lines )
6059
60+ async def _consume_generator (self , generator ):
61+ final_output = None
62+ async for request_output in generator :
63+ final_output = request_output
64+ return final_output
65+
6166 async def generate_answer (
6267 self , text : str , history : Optional [List [str ]] = None , ** extra : Any
6368 ) -> str :
@@ -72,14 +77,21 @@ async def generate_answer(
7277
7378 result_generator = self .engine .generate (full_prompt , sp , request_id = request_id )
7479
75- final_output = None
76- async for request_output in result_generator :
77- final_output = request_output
80+ try :
81+ final_output = await asyncio .wait_for (
82+ self ._consume_generator (result_generator ),
83+ timeout = self .timeout
84+ )
85+
86+ if not final_output or not final_output .outputs :
87+ return ""
7888
79- if not final_output or not final_output .outputs :
80- return ""
89+ result_text = final_output .outputs [ 0 ]. text
90+ return result_text
8191
82- return final_output .outputs [0 ].text
92+ except (Exception , asyncio .CancelledError ):
93+ await self .engine .abort (request_id )
94+ raise
8395
8496 async def generate_topk_per_token (
8597 self , text : str , history : Optional [List [str ]] = None , ** extra : Any
@@ -91,42 +103,47 @@ async def generate_topk_per_token(
91103 temperature = 0 ,
92104 max_tokens = 1 ,
93105 logprobs = self .top_k ,
94- prompt_logprobs = 1 ,
95106 )
96107
97108 result_generator = self .engine .generate (full_prompt , sp , request_id = request_id )
98109
99- final_output = None
100- async for request_output in result_generator :
101- final_output = request_output
102-
103- if (
104- not final_output
105- or not final_output .outputs
106- or not final_output .outputs [0 ].logprobs
107- ):
108- return []
109-
110- top_logprobs = final_output .outputs [0 ].logprobs [0 ]
111-
112- candidate_tokens = []
113- for _ , logprob_obj in top_logprobs .items ():
114- tok_str = (
115- logprob_obj .decoded_token .strip () if logprob_obj .decoded_token else ""
110+ try :
111+ final_output = await asyncio .wait_for (
112+ self ._consume_generator (result_generator ),
113+ timeout = self .timeout
116114 )
117- prob = float (math .exp (logprob_obj .logprob ))
118- candidate_tokens .append (Token (tok_str , prob ))
119115
120- candidate_tokens .sort (key = lambda x : - x .prob )
116+ if (
117+ not final_output
118+ or not final_output .outputs
119+ or not final_output .outputs [0 ].logprobs
120+ ):
121+ return []
122+
123+ top_logprobs = final_output .outputs [0 ].logprobs [0 ]
124+
125+ candidate_tokens = []
126+ for _ , logprob_obj in top_logprobs .items ():
127+ tok_str = (
128+ logprob_obj .decoded_token .strip () if logprob_obj .decoded_token else ""
129+ )
130+ prob = float (math .exp (logprob_obj .logprob ))
131+ candidate_tokens .append (Token (tok_str , prob ))
132+
133+ candidate_tokens .sort (key = lambda x : - x .prob )
134+
135+ if candidate_tokens :
136+ main_token = Token (
137+ text = candidate_tokens [0 ].text ,
138+ prob = candidate_tokens [0 ].prob ,
139+ top_candidates = candidate_tokens ,
140+ )
141+ return [main_token ]
142+ return []
121143
122- if candidate_tokens :
123- main_token = Token (
124- text = candidate_tokens [0 ].text ,
125- prob = candidate_tokens [0 ].prob ,
126- top_candidates = candidate_tokens ,
127- )
128- return [main_token ]
129- return []
144+ except (Exception , asyncio .CancelledError ):
145+ await self .engine .abort (request_id )
146+ raise
130147
131148 async def generate_inputs_prob (
132149 self , text : str , history : Optional [List [str ]] = None , ** extra : Any
0 commit comments