import math
+from tracemalloc import stop
import uuid
from typing import Any, List, Optional
import asyncio
@@ -42,20 +43,17 @@ def __init__(
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        self.timeout = float(timeout)
+        self.tokenizer = self.engine.engine.tokenizer.tokenizer

-    @staticmethod
-    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
-        msgs = history or []
-        lines = []
-        for m in msgs:
-            if isinstance(m, dict):
-                role = m.get("role", "")
-                content = m.get("content", "")
-                lines.append(f"{role}: {content}")
-            else:
-                lines.append(str(m))
-        lines.append(prompt)
-        return "\n".join(lines)
+    def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
+        messages = history or []
+        messages.append({"role": "user", "content": prompt})
+
+        return self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )

    async def _consume_generator(self, generator):
        final_output = None
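The new _build_inputs delegates prompt formatting to the model's chat template instead of joining "role: content" lines by hand. A minimal sketch of what that call produces, assuming a Hugging Face tokenizer whose model ships a chat template (the model name below is only an example, not necessarily the one used by this project):

from transformers import AutoTokenizer

# Illustrative model; any instruct model with a chat template behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is a knowledge graph?"},
]

# tokenize=False returns the rendered prompt string rather than token ids;
# add_generation_prompt=True appends the assistant header so the model starts a fresh turn.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)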
@@ -70,14 +68,14 @@ async def generate_answer(
        request_id = f"graphgen_req_{uuid.uuid4()}"

        sp = self.SamplingParams(
-            temperature=self.temperature if self.temperature > 0 else 1.0,
-            top_p=self.top_p if self.temperature > 0 else 1.0,
+            temperature=self.temperature if self.temperature >= 0 else 1.0,
+            top_p=self.top_p if self.top_p >= 0 else 1.0,
            max_tokens=extra.get("max_new_tokens", 2048),
+            repetition_penalty=extra.get("repetition_penalty", 1.05),
        )

-        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
-
        try:
+            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
            final_output = await asyncio.wait_for(
                self._consume_generator(result_generator),
                timeout=self.timeout
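The sampling guards above now fall back per field (top_p was previously keyed off temperature), and a repetition_penalty default is added. A standalone sketch of the same construction against vLLM's SamplingParams, with the wrapper's attributes passed in as plain arguments (the helper name is only for illustration):

from vllm import SamplingParams

def build_sampling_params(temperature: float, top_p: float, extra: dict) -> SamplingParams:
    # Mirrors the fixed guard logic: each field falls back to 1.0 independently
    # when its own value is negative, instead of both being gated on temperature.
    return SamplingParams(
        temperature=temperature if temperature >= 0 else 1.0,
        top_p=top_p if top_p >= 0 else 1.0,
        max_tokens=extra.get("max_new_tokens", 2048),
        repetition_penalty=extra.get("repetition_penalty", 1.05),
    )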
@@ -105,9 +103,8 @@ async def generate_topk_per_token(
            logprobs=self.top_k,
        )

-        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
-
        try:
+            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
            final_output = await asyncio.wait_for(
                self._consume_generator(result_generator),
                timeout=self.timeout
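With logprobs=self.top_k set, the finished RequestOutput carries per-token candidates. A hedged sketch of reading them back out; the layout of CompletionOutput.logprobs has shifted between vLLM releases, so treat the attribute access below as an assumption (recent versions return one {token_id: Logprob} dict per generated token):

def extract_topk(final_output, tokenizer):
    # final_output is the RequestOutput collected by _consume_generator above.
    completion = final_output.outputs[0]
    topk_per_token = []
    for candidates in completion.logprobs or []:
        topk_per_token.append([
            # Logprob.decoded_token may be None, in which case decode the id directly.
            (lp.decoded_token or tokenizer.decode([token_id]), lp.logprob)
            for token_id, lp in candidates.items()
        ])
    return topk_per_token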