
Commit d1b3b6e

fix: add repetition_penalty (#155)
* fix: add repetition_penalty
* Potential fix for pull request finding 'Unused import'
* fix: catch timeouterror
* fix: fix no attribute timeout
* fix: change vllm timeout from 300 to 600

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
1 parent 952a2d4 commit d1b3b6e

1 file changed: 22 additions & 23 deletions

File tree

graphgen/models/llm/local/vllm_wrapper.py

```diff
@@ -20,7 +20,7 @@ def __init__(
         temperature: float = 0.6,
         top_p: float = 1.0,
         top_k: int = 5,
-        timeout: float = 300,
+        timeout: float = 600,
         **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)
```
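The only change here doubles the default request timeout from 300 s to 600 s. The value later bounds `asyncio.wait_for` on the engine's output stream, and `wait_for` signals expiry by raising `asyncio.TimeoutError`. A minimal runnable sketch of that mechanism (the coroutine and delays are illustrative, not from the commit):

```python
import asyncio

async def slow_generation():
    # Stand-in for consuming the vLLM result generator.
    await asyncio.sleep(2)
    return "answer"

async def main():
    try:
        # timeout plays the role of self.timeout (now 600 s by default).
        print(await asyncio.wait_for(slow_generation(), timeout=0.5))
    except asyncio.TimeoutError:
        print("request timed out")

asyncio.run(main())
```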
```diff
@@ -42,25 +42,24 @@ def __init__(
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
         self.timeout = float(timeout)
+        self.tokenizer = self.engine.engine.tokenizer.tokenizer
 
-    @staticmethod
-    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
-        msgs = history or []
-        lines = []
-        for m in msgs:
-            if isinstance(m, dict):
-                role = m.get("role", "")
-                content = m.get("content", "")
-                lines.append(f"{role}: {content}")
-            else:
-                lines.append(str(m))
-        lines.append(prompt)
-        return "\n".join(lines)
+    def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
+        messages = history or []
+        messages.append({"role": "user", "content": prompt})
+
+        return self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
 
     async def _consume_generator(self, generator):
         final_output = None
         async for request_output in generator:
-            final_output = request_output
+            if request_output.finished:
+                final_output = request_output
+                break
         return final_output
 
     async def generate_answer(
```
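The rewritten `_build_inputs` replaces the hand-rolled `role: content` join with the tokenizer's own chat template, so prompts match the format the model was fine-tuned on. A standalone sketch of the same call via Hugging Face `transformers`; the model name is an arbitrary example, not taken from the commit:

```python
from transformers import AutoTokenizer

# Any chat-tuned model works here; this name is illustrative only.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is a knowledge graph?"},
]

# tokenize=False returns the rendered prompt string; add_generation_prompt=True
# appends the assistant header so decoding starts at the model's turn.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
```

One side effect worth noting: the new method appends the user turn to `history` in place, so callers that pass a long-lived, non-empty list will see it grow across calls.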
```diff
@@ -70,14 +69,14 @@ async def generate_answer(
         request_id = f"graphgen_req_{uuid.uuid4()}"
 
         sp = self.SamplingParams(
-            temperature=self.temperature if self.temperature > 0 else 1.0,
-            top_p=self.top_p if self.temperature > 0 else 1.0,
+            temperature=self.temperature if self.temperature >= 0 else 1.0,
+            top_p=self.top_p if self.top_p >= 0 else 1.0,
             max_tokens=extra.get("max_new_tokens", 2048),
+            repetition_penalty=extra.get("repetition_penalty", 1.05),
         )
 
-        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
-
         try:
+            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
             final_output = await asyncio.wait_for(
                 self._consume_generator(result_generator),
                 timeout=self.timeout
```
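This hunk lands the headline fix: a `repetition_penalty` passed through `SamplingParams`, defaulting to 1.05 when the caller supplies none. A minimal sketch of the equivalent standalone construction, assuming the public `vllm.SamplingParams` class (the wrapper reaches it via `self.SamplingParams`):

```python
from vllm import SamplingParams

# Mirrors the defaults in the diff; extra.get(...) falls back to these
# values when the caller passes nothing.
sp = SamplingParams(
    temperature=0.6,
    top_p=1.0,
    max_tokens=2048,
    # Values > 1.0 penalize tokens that already appeared, discouraging
    # repetition loops; 1.0 disables the penalty entirely.
    repetition_penalty=1.05,
)
```

The comparison changes are also meaningful: with `>= 0`, a configured `temperature=0.0` (greedy decoding) is no longer silently rewritten to 1.0, and `top_p` is now gated on its own value rather than on `temperature`.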
```diff
@@ -89,7 +88,7 @@ async def generate_answer(
             result_text = final_output.outputs[0].text
             return result_text
 
-        except (Exception, asyncio.CancelledError):
+        except (Exception, asyncio.CancelledError, asyncio.TimeoutError):
             await self.engine.abort(request_id)
             raise
```
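The except clause now names `asyncio.TimeoutError` explicitly. On current CPython versions it already derives from `Exception`, so the entry mostly documents intent, whereas `asyncio.CancelledError` has inherited from `BaseException` since Python 3.8 and genuinely needs its own slot. Together with moving `engine.generate` inside the `try`, every failure path now reaches `engine.abort`. A self-contained sketch of the pattern with stand-in stubs (`FakeEngine` and its outputs are illustrative, not vLLM's API):

```python
import asyncio

class FakeEngine:
    """Stand-in for AsyncLLMEngine, just enough to show the control flow."""

    async def generate(self, prompt, sp, request_id):
        for _ in range(3):
            await asyncio.sleep(10)  # pretend each decode step is slow
            yield type("Out", (), {"finished": False})()

    async def abort(self, request_id):
        print(f"aborted {request_id}")

async def consume(generator):
    # Mirror of _consume_generator: stop at the first finished output.
    async for request_output in generator:
        if request_output.finished:
            return request_output
    return None

async def generate_with_abort(engine, prompt, sp, request_id, timeout):
    try:
        # Created inside the try so a failure here still triggers abort.
        generator = engine.generate(prompt, sp, request_id=request_id)
        return await asyncio.wait_for(consume(generator), timeout=timeout)
    except (Exception, asyncio.CancelledError, asyncio.TimeoutError):
        await engine.abort(request_id)  # free the in-flight request
        raise

# Raises asyncio.TimeoutError after printing "aborted req-1":
# asyncio.run(generate_with_abort(FakeEngine(), "hi", None, "req-1", 0.1))
```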

```diff
@@ -105,14 +104,14 @@ async def generate_topk_per_token(
             logprobs=self.top_k,
         )
 
-        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
-
         try:
+            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
             final_output = await asyncio.wait_for(
                 self._consume_generator(result_generator),
                 timeout=self.timeout
             )
 
+
             if (
                 not final_output
                 or not final_output.outputs
```
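For `generate_topk_per_token`, the sampling parameters request per-position log-probabilities via `logprobs=self.top_k`. A brief sketch of what that asks vLLM for, again assuming the public `vllm.SamplingParams`; the result layout in the comments is how vLLM reports sample log-probabilities:

```python
from vllm import SamplingParams

# Ask for the k most likely candidates at each generated position.
sp = SamplingParams(max_tokens=1, logprobs=5)

# After generation, request_output.outputs[0].logprobs is a list with one
# dict per generated token, mapping token_id -> Logprob (which carries the
# log-probability and the decoded token text).
```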
```diff
@@ -141,7 +140,7 @@ async def generate_topk_per_token(
                 return [main_token]
             return []
 
-        except (Exception, asyncio.CancelledError):
+        except (Exception, asyncio.CancelledError, asyncio.TimeoutError):
             await self.engine.abort(request_id)
             raise
```
