Skip to content

Commit 0f325c8

Browse files
fix: add repetition_penalty
1 parent 952a2d4 commit 0f325c8

1 file changed

Lines changed: 16 additions & 19 deletions

File tree

graphgen/models/llm/local/vllm_wrapper.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from tracemalloc import stop
23
import uuid
34
from typing import Any, List, Optional
45
import asyncio
@@ -42,20 +43,17 @@ def __init__(
4243
)
4344
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
4445
self.timeout = float(timeout)
46+
self.tokenizer = self.engine.engine.tokenizer.tokenizer
4547

46-
@staticmethod
47-
def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
48-
msgs = history or []
49-
lines = []
50-
for m in msgs:
51-
if isinstance(m, dict):
52-
role = m.get("role", "")
53-
content = m.get("content", "")
54-
lines.append(f"{role}: {content}")
55-
else:
56-
lines.append(str(m))
57-
lines.append(prompt)
58-
return "\n".join(lines)
48+
def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
49+
messages = history or []
50+
messages.append({"role": "user", "content": prompt})
51+
52+
return self.tokenizer.apply_chat_template(
53+
messages,
54+
tokenize=False,
55+
add_generation_prompt=True
56+
)
5957

6058
async def _consume_generator(self, generator):
6159
final_output = None
@@ -70,14 +68,14 @@ async def generate_answer(
7068
request_id = f"graphgen_req_{uuid.uuid4()}"
7169

7270
sp = self.SamplingParams(
73-
temperature=self.temperature if self.temperature > 0 else 1.0,
74-
top_p=self.top_p if self.temperature > 0 else 1.0,
71+
temperature=self.temperature if self.temperature >= 0 else 1.0,
72+
top_p=self.top_p if self.top_p >= 0 else 1.0,
7573
max_tokens=extra.get("max_new_tokens", 2048),
74+
repetition_penalty=extra.get("repetition_penalty", 1.05),
7675
)
7776

78-
result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
79-
8077
try:
78+
result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
8179
final_output = await asyncio.wait_for(
8280
self._consume_generator(result_generator),
8381
timeout=self.timeout
@@ -105,9 +103,8 @@ async def generate_topk_per_token(
105103
logprobs=self.top_k,
106104
)
107105

108-
result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
109-
110106
try:
107+
result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
111108
final_output = await asyncio.wait_for(
112109
self._consume_generator(result_generator),
113110
timeout=self.timeout

0 commit comments

Comments
 (0)