Skip to content

Commit ba85520

Browse files
Adds support for different LLMs
1 parent d729167 commit ba85520

4 files changed

Lines changed: 95 additions & 13 deletions

File tree

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,7 @@ bin/gpu_benchmark
5757
test/
5858
llmtests/
5959

60-
scratch/
60+
scratch/
61+
62+
models.yaml
63+
.env

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ dependencies = [
4040
"stringcase",
4141
"filelock",
4242
"pathos",
43+
"python-dotenv",
4344
]
4445
dynamic = ["version", "entry-points", "scripts"]
4546

wfcommons/wfbench/translator/llm_translator.py

Lines changed: 89 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,20 @@
55
from wfcommons.wfbench.translator and implements the required interface.
66
"""
77

8+
import os
9+
import re
10+
from pathlib import Path
11+
812
import requests
9-
from wfcommons.wfbench.bench import WorkflowBenchmark
13+
import yaml
14+
from dotenv import load_dotenv
1015
from typing import Optional, Dict, Any, List
1116
from wfcommons.wfbench.translator.utils.llm_client import LLMClient
1217

18+
load_dotenv() # loads .env from cwd (project root)
19+
20+
MODELS_YAML = Path("models.yaml")
21+
1322

1423
class LLMTranslator():
1524
"""
@@ -22,35 +31,104 @@ class LLMTranslator():
2231
"""
2332

2433
def __init__(self,
25-
llm_client: LLMClient,
26-
examples_instances: Optional[List[str]],
34+
llm_client: LLMClient | None = None,
35+
model_name: str | None = None,
36+
models_file: str | Path | None = None,
37+
examples_instances: Optional[List[str]] = None,
2738
num_examples: int = 3,
2839
system_prompt: Optional[str] = None,
2940
**kwargs
3041
):
3142
"""
3243
Parameters
3344
----------
34-
llm_client : Any
35-
An object with `.complete(prompt: str) -> str`.
45+
llm_client : LLMClient, optional
46+
A pre-configured LLMClient instance. Either this or
47+
``model_name`` must be provided.
48+
model_name : str, optional
49+
Key from models.yaml (e.g. "qwen3", "ollama/llama3").
50+
The matching config is used to build an LLMClient automatically.
51+
models_file : str or Path, optional
52+
Path to a custom models YAML file. Defaults to the
53+
``models.yaml`` shipped alongside this module.
3654
examples_instances : List[str]
3755
URLs pointing to translator examples or benchmarks:
3856
- raw GitHub links
3957
- JSON traces
4058
- scripts
4159
num_examples : int, optional
42-
Number of example instances to include in the prompt.
60+
Number of example instances to include in the prompt.
4361
system_prompt : str, optional
4462
Override the default system instructions for the LLM.
4563
kwargs : dict
4664
Additional parameters passed to the parent Translator if needed.
4765
"""
4866
super().__init__(**kwargs)
4967

68+
if llm_client is None and model_name is None:
69+
raise ValueError("Provide either llm_client or model_name.")
70+
if llm_client is not None and model_name is not None:
71+
raise ValueError("Provide only one of llm_client or model_name, not both.")
72+
73+
if model_name is not None:
74+
llm_client = self._client_from_yaml(model_name, models_file)
75+
5076
self.llm = llm_client
5177
self.examples_instances = examples_instances
5278
self.num_examples = num_examples
5379
self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
80+
81+
# ------------------------------------------------------------------ #
82+
# YAML helpers #
83+
# ------------------------------------------------------------------ #
84+
85+
@staticmethod
86+
def available_models(models_file: str | Path | None = None) -> list[str]:
87+
"""Return the list of model keys defined in models.yaml."""
88+
cfg = LLMTranslator._load_models_yaml(models_file)
89+
return list(cfg.keys())
90+
91+
@staticmethod
92+
def _load_models_yaml(models_file: str | Path | None = None) -> dict:
93+
path = Path(models_file) if models_file else MODELS_YAML
94+
with open(path) as f:
95+
return yaml.safe_load(f)
96+
97+
@staticmethod
98+
def _resolve_env(value: str) -> str:
99+
"""Replace ``${VAR}`` placeholders with environment variables."""
100+
def _replace(m):
101+
var = m.group(1)
102+
val = os.environ.get(var)
103+
if val is None:
104+
raise EnvironmentError(
105+
f"Environment variable '{var}' is not set "
106+
f"(required by models.yaml)."
107+
)
108+
return val
109+
return re.sub(r"\$\{(\w+)\}", _replace, value)
110+
111+
@staticmethod
112+
def _client_from_yaml(model_name: str,
113+
models_file: str | Path | None = None) -> LLMClient:
114+
cfg = LLMTranslator._load_models_yaml(models_file)
115+
if model_name not in cfg:
116+
raise KeyError(
117+
f"Model '{model_name}' not found in models.yaml. "
118+
f"Available: {list(cfg.keys())}"
119+
)
120+
entry = cfg[model_name]
121+
api_key = LLMTranslator._resolve_env(str(entry["api_key"]))
122+
base_url = entry.get("base_url")
123+
if base_url and base_url != "null":
124+
base_url = LLMTranslator._resolve_env(str(base_url))
125+
else:
126+
base_url = None
127+
return LLMClient(
128+
model=entry["model"],
129+
api_key=api_key,
130+
base_url=base_url,
131+
)
54132

55133
def _load_examples(self,
56134
path_list: List[str] ,
@@ -128,16 +206,17 @@ def translate(self, trace, metadata=None, json_schema: dict | None = None, **kwa
128206
metadata=metadata,
129207
)
130208

131-
output = self.llm.complete(
132-
prompt,
133-
response_format={
209+
response_format = None
210+
if json_schema is not None:
211+
response_format = {
134212
"type": "json_schema",
135213
"json_schema": {
136214
"name": "WfFormat",
137215
"schema": json_schema
138216
}
139217
}
140-
)
218+
219+
output = self.llm.complete(prompt, response_format=response_format)
141220
return output
142221

143222
def _retrieve_examples(self, trace_text: str):

wfcommons/wfbench/translator/utils/llm_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ def complete(self,
6666
Format specification for structured output (e.g., JSON schema).
6767
Note: Not all providers support this parameter.
6868
temperature : float, optional
69-
Sampling temperature (0.0 = deterministic). Defaults to 0.0.
70-
69+
Sampling temperature (default: 0.0 for deterministic output).
7170
Returns
7271
-------
7372
str

0 commit comments

Comments
 (0)