@@ -11,9 +11,9 @@
 import logging
 import os
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any
 
-from transformers import AutoModelForCausalLM, AutoTokenizer  # type: ignore
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 logger = logging.getLogger(__name__)
 
@@ -64,8 +64,8 @@ def __init__(
         self.max_tokens = max_tokens
         self.temperature = temperature
 
-        self._model: Optional[Any] = None
-        self._tokenizer: Optional[Any] = None
+        self._model: Any | None = None
+        self._tokenizer: Any | None = None
         self._initialized = False
 
     def _lazy_init(self) -> None:
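Side note on the `Optional[Any]` → `Any | None` change: the PEP 604 `X | Y` union spelling is only evaluated at runtime on Python 3.10+, so this hunk implicitly assumes a 3.10+ floor, or a `from __future__ import annotations` at the top of the module (which defers annotation evaluation and works from 3.7). A minimal sketch; the `Engine` class name is a placeholder, not the real class in this file:

```python
from __future__ import annotations  # defers evaluation, so `Any | None` parses on 3.7+

from typing import Any


class Engine:  # placeholder name for illustration
    def __init__(self) -> None:
        # Same meaning as Optional[Any]: both attributes stay None until lazy init runs.
        self._model: Any | None = None
        self._tokenizer: Any | None = None
```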
@@ -250,24 +250,24 @@ def _generate(self, system_prompt: str, user_prompt: str) -> str:
         ]
 
         # Apply chat template
-        text = self._tokenizer.apply_chat_template(  # type: ignore
+        text = self._tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
 
         # Tokenize
-        inputs = self._tokenizer([text], return_tensors="pt").to(self._model.device)  # type: ignore
+        inputs = self._tokenizer([text], return_tensors="pt").to(self._model.device)
 
         # Generate
-        outputs = self._model.generate(  # type: ignore
+        outputs = self._model.generate(
             **inputs,
             max_new_tokens=self.max_tokens,
             temperature=self.temperature if self.temperature > 0 else None,
             do_sample=self.temperature > 0,
-            pad_token_id=self._tokenizer.eos_token_id,  # type: ignore
+            pad_token_id=self._tokenizer.eos_token_id,
         )
 
         # Decode
-        generated_text: str = self._tokenizer.decode(outputs[0], skip_special_tokens=True)  # type: ignore
+        generated_text: str = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
 
         # Extract response (everything after the user prompt)
         # This handles the chat template format
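For anyone who wants to poke at this flow outside the class, here is a minimal, self-contained sketch of the same generate path. The model name is an arbitrary small chat model chosen for illustration, and the tail-slice decode is one common way to implement the "everything after the user prompt" extraction the trailing comment describes; the hunk itself doesn't show how this file actually does it.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # arbitrary example; any chat model works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello."},
]

# Render the chat template to a plain string; generation is done separately.
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Tokenize and move tensors to the model's device.
inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Sample when temperature > 0, greedy-decode otherwise (mirrors the diff).
temperature = 0.7
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    temperature=temperature if temperature > 0 else None,
    do_sample=temperature > 0,
    pad_token_id=tokenizer.eos_token_id,
)

# Decode only the newly generated tokens, skipping the echoed prompt.
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))
```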