@@ -20,7 +20,7 @@ def __init__(
2020 temperature : float = 0.6 ,
2121 top_p : float = 1.0 ,
2222 top_k : int = 5 ,
23- timeout : float = 300 ,
23+ timeout : float = 600 ,
2424 ** kwargs : Any ,
2525 ):
2626 super ().__init__ (temperature = temperature , top_p = top_p , top_k = top_k , ** kwargs )
@@ -42,25 +42,24 @@ def __init__(
4242 )
4343 self .engine = AsyncLLMEngine .from_engine_args (engine_args )
4444 self .timeout = float (timeout )
45+ self .tokenizer = self .engine .engine .tokenizer .tokenizer
4546
46- @staticmethod
47- def _build_inputs (prompt : str , history : Optional [List [str ]] = None ) -> str :
48- msgs = history or []
49- lines = []
50- for m in msgs :
51- if isinstance (m , dict ):
52- role = m .get ("role" , "" )
53- content = m .get ("content" , "" )
54- lines .append (f"{ role } : { content } " )
55- else :
56- lines .append (str (m ))
57- lines .append (prompt )
58- return "\n " .join (lines )
def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
    """Render a chat-formatted prompt for the engine.

    Combines the prior conversation *history* with the new user *prompt*
    and formats the result with the model tokenizer's chat template.

    Args:
        prompt: The new user message to append as a ``{"role": "user"}`` turn.
        history: Optional list of prior OpenAI-style message dicts
            (``{"role": ..., "content": ...}``). The caller's list is NOT
            modified by this call.

    Returns:
        The templated prompt; with ``tokenize=False`` the HF tokenizer
        returns a string ready to hand to the engine.
    """
    # BUG FIX: the original did `messages = history or []` then appended,
    # mutating the caller's history list on every call (the user turn would
    # accumulate across invocations). Copy before appending.
    messages = list(history) if history else []
    messages.append({"role": "user", "content": prompt})

    return self.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
5956
async def _consume_generator(self, generator):
    """Drain an async request-output stream until a finished result arrives.

    Iterates *generator* and returns the first item whose ``finished``
    flag is set; returns ``None`` if the stream ends without one.
    """
    async for chunk in generator:
        if chunk.finished:
            # First completed output is the final answer — stop consuming.
            return chunk
    return None
6564
6665 async def generate_answer (
@@ -70,14 +69,14 @@ async def generate_answer(
7069 request_id = f"graphgen_req_{ uuid .uuid4 ()} "
7170
7271 sp = self .SamplingParams (
73- temperature = self .temperature if self .temperature > 0 else 1.0 ,
74- top_p = self .top_p if self .temperature > 0 else 1.0 ,
72+ temperature = self .temperature if self .temperature >= 0 else 1.0 ,
73+ top_p = self .top_p if self .top_p >= 0 else 1.0 ,
7574 max_tokens = extra .get ("max_new_tokens" , 2048 ),
75+ repetition_penalty = extra .get ("repetition_penalty" , 1.05 ),
7676 )
7777
78- result_generator = self .engine .generate (full_prompt , sp , request_id = request_id )
79-
8078 try :
79+ result_generator = self .engine .generate (full_prompt , sp , request_id = request_id )
8180 final_output = await asyncio .wait_for (
8281 self ._consume_generator (result_generator ),
8382 timeout = self .timeout
@@ -89,7 +88,7 @@ async def generate_answer(
8988 result_text = final_output .outputs [0 ].text
9089 return result_text
9190
92- except (Exception , asyncio .CancelledError ):
91+ except (Exception , asyncio .CancelledError , asyncio . TimeoutError ):
9392 await self .engine .abort (request_id )
9493 raise
9594
@@ -105,14 +104,14 @@ async def generate_topk_per_token(
105104 logprobs = self .top_k ,
106105 )
107106
108- result_generator = self .engine .generate (full_prompt , sp , request_id = request_id )
109-
110107 try :
108+ result_generator = self .engine .generate (full_prompt , sp , request_id = request_id )
111109 final_output = await asyncio .wait_for (
112110 self ._consume_generator (result_generator ),
113111 timeout = self .timeout
114112 )
115113
114+
116115 if (
117116 not final_output
118117 or not final_output .outputs
@@ -141,7 +140,7 @@ async def generate_topk_per_token(
141140 return [main_token ]
142141 return []
143142
144- except (Exception , asyncio .CancelledError ):
143+ except (Exception , asyncio .CancelledError , asyncio . TimeoutError ):
145144 await self .engine .abort (request_id )
146145 raise
147146
0 commit comments