Commit e1726ad

feat: add support for the original mt-bench (#21)
* Add llamacpp dependency and update gitignore with generated directories
* Add documentation for llamacpp in README
* Document direnv usage for environment variable management
* Narrow down transformers dependency to fix version mismatch
* Add max_model_len param for VLLM to prevent OOM errors
* Fix completion loading and EuroLLM-9B example
  - Updated README to use EuroLLM-Instruct because the base (EuroLLM-9B) doesn't have a chat template and throws an error.
  - Added functionality to load pre-existing dataset completions for models. Previously this threw an error because the model was treated as a provider.
* Remove `direnv` documentation
* Revert stylistic (formatting) changes and add more documentation for the new `max_model_len` and related parameters
* Rename OPENJURY_EVAL_DATA to OPENJURY_DATA
* Revert changes in gitignore
* Handle models with max_position_embeddings when we pass max_model_len
  - Moved max_model_len and chat_template to **model_kwargs for readability.
  - Adjusted ChatVLLM initialization to cap max_model_len based on the model's max_position_embeddings.
  - Added warnings for potential max_model_len issues.
* Revert EuroLLM-9B-Instruct to EuroLLM-9B since there is a default chat template now
* Fix tests
  - Mock external API calls
  - Add safety check for content in completions
* Change test GitHub workflow to use uv instead of pip for more robust dependency resolution
* Move dev dependencies to a dependency group
* Revert comment removal
* Add pre-commit hook
* Add project scripts and move slurmpilot to dev group
  - Moved slurmpilot to the dev group since it doesn't have a published version on PyPI, which would otherwise block publishing OpenJury on PyPI.
* Fix LlamaCpp bug with ChatTemplate (see the sketch after this list)
  - There was a halting issue with LlamaCpp: the model was not emitting the EOS token and Llama.reset() is not called between calls (turns), causing a KV cache position mismatch crash, so ChatLlamaCppModel was created as a custom wrapper to fix this.
  - BaseLocalModel was extracted as common logic for ChatLlamaCppModel and ChatVLLM.
* Add MT-Bench multi-turn evaluation support
  - Implement MT-Bench loader and multi-turn generation/judging logic.
  - Add paper-aligned prompt templates while keeping the score-based evaluation consistent with OpenJury.
  - Support reference answers, per-turn breakdowns, and swap mode.
  - Add comprehensive MT-Bench pipeline tests.
* Fix result formatting
* Remove double environment variable
* Remove accidental duplications
* Refactor
  - Implement a new function to download MT-Bench questions and GPT-4 reference answers, with fallback mechanisms for missing references.
  - Remove duplication.
* Remove duplication between prompt templates
* Add temperature argument
* Add option for making MT-Bench consistent with the original one from FastChat
* Remove redundant print statement
* Move MT-Bench logic out of the entrypoint
* Remove stale unused entries for fastchat mode
* Refactor MT-Bench eval helpers into shared runtime module
* Move CLI args and parsing to a separate util to remove dependencies on the entrypoint
* Refactor to address comments on PR
* Remove OpenJury mode for MT-Bench, keeping only the original version
* Restore code and fix after merge/refactor
* Format
* Fix CI
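
The LlamaCpp fix above boils down to clearing the KV cache between turns. A minimal sketch of the idea, assuming llama-cpp-python's `Llama.reset()` and `create_chat_completion()`; the wrapper and its `invoke` method are illustrative stand-ins, not the repository's actual `ChatLlamaCppModel`:

from llama_cpp import Llama  # llama-cpp-python

class ChatLlamaCppModel:
    # Illustrative wrapper; the real class in this commit also shares
    # common logic with ChatVLLM through BaseLocalModel.
    def __init__(self, model_path: str, **engine_kwargs):
        self._llm = Llama(model_path=model_path, **engine_kwargs)

    def invoke(self, messages: list[dict]) -> str:
        # Reset KV-cache positions before each call so consecutive turns
        # start from position 0 instead of crashing on a position mismatch.
        self._llm.reset()
        out = self._llm.create_chat_completion(messages=messages)
        return out["choices"][0]["message"]["content"]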
1 parent 223db1b commit e1726ad

20 files changed

Lines changed: 1630 additions & 13 deletions

judgearena/config.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
"""CLI argument configuration for generation and evaluation entrypoints."""

import argparse
import json
from dataclasses import dataclass, field


@dataclass
class CliArgs:
    dataset: str
    model_A: str
    model_B: str
    judge_model: str

    n_instructions: int | None = None
    provide_explanation: bool = False
    swap_mode: str = "fixed"
    ignore_cache: bool = False
    use_tqdm: bool = False
    truncate_all_input_chars: int = 8192
    max_out_tokens_models: int = 32768
    max_out_tokens_judge: int = 32768
    max_model_len: int | None = None
    chat_template: str | None = None
    result_folder: str = "results"
    engine_kwargs: dict = field(default_factory=dict)

    def __post_init__(self):
        supported_modes = ["fixed", "both"]
        assert self.swap_mode in supported_modes, (
            f"Only {supported_modes} modes are supported but got {self.swap_mode}."
        )

    @classmethod
    def parse_args(cls):
        parser = argparse.ArgumentParser(
            prog="Generate completion and evaluate with a judge",
        )
        parser.add_argument(
            "--dataset",
            help="The dataset to use. For instance `alpaca-eval`, `arena-hard`, `m-arena-hard-EU` for instruction "
            "tuning cases or `french-contexts`, `spanish-contexts` for base models.",
        )
        parser.add_argument(
            "--model_A",
            required=True,
            help="Name of the LLM to use for generation; must be a valid choice for `generation_provider`",
        )
        parser.add_argument(
            "--model_B",
            required=True,
            help="Name of the LLM to use for generation; must be a valid choice for `generation_provider`",
        )
        parser.add_argument(
            "--judge_model",
            required=True,
            help="Name of the LLM to use, for instance `Together/meta-llama/Meta-Llama-3-70B-Instruct-Turbo`, "
            "`VLLM/meta-llama/Meta-Llama-3-70B-Instruct-Turbo`, `LangChain/LocalPath` etc",
        )
        parser.add_argument(
            "--n_instructions",
            type=int,
            required=False,
        )
        parser.add_argument(
            "--provide_explanation",
            action="store_true",
            help="If specified, the judge will provide an explanation before making a judgement. Does not "
            "necessarily improve the accuracy of the judge but enables some result interpretation.",
        )
        parser.add_argument(
            "--swap_mode",
            type=str,
            choices=["fixed", "both"],
            default="fixed",
            help="Model comparison order mode. 'fixed': always use model order A-B. 'both': correct for model order "
            "bias by evaluating each instruction twice, once as A-B and once as B-A, and average. This helps account "
            "for judge position bias. Default is 'fixed'.",
        )
        parser.add_argument(
            "--ignore_cache",
            action="store_true",
            help="If specified, ignore cache of previous completions.",
        )
        parser.add_argument(
            "--use_tqdm",
            action="store_true",
            help="If specified, use tqdm; does not work with all model providers, vLLM in particular.",
        )
        parser.add_argument(
            "--result_folder",
            type=str,
            required=False,
            default="results",
            help="The folder to save the results. Defaults to `results`. Evaluation results will be saved in"
            " `[result_folder]/[evaluation_name]`.",
        )
        parser.add_argument(
            "--truncate_all_input_chars",
            type=int,
            required=False,
            default=8192,
            help="Character-level truncation applied before tokenization: truncates each instruction "
            "before model A/B generation and truncates each completion before judge evaluation.",
        )
        parser.add_argument(
            "--max_out_tokens_models",
            type=int,
            required=False,
            default=32768,
            help=(
                "Generation token budget for each model A/B response. For VLLM, keep this <= "
                "--max_model_len (if provided)."
            ),
        )
        parser.add_argument(
            "--max_out_tokens_judge",
            type=int,
            required=False,
            default=32768,
            help=(
                "Generation token budget for the judge response (reasoning + scores). For "
                "VLLM, keep this <= --max_model_len (if provided)."
            ),
        )
        parser.add_argument(
            "--max_model_len",
            type=int,
            required=False,
            default=None,
            help=(
                "Optional total context window for VLLM models (prompt + generation). This is "
                "independent from --max_out_tokens_models/--max_out_tokens_judge, which only cap "
                "generated tokens. This is useful on smaller GPUs to avoid OOM."
            ),
        )
        parser.add_argument(
            "--chat_template",
            type=str,
            required=False,
            default=None,
            help="Jinja2 chat template string to use instead of the model's tokenizer template. "
            "If not provided, ChatML is used as a fallback for models without a chat template.",
        )
        parser.add_argument(
            "--engine_kwargs",
            type=str,
            required=False,
            default="{}",
            help=(
                "JSON dict of engine-specific kwargs forwarded to the underlying engine. "
                'Example for vLLM: \'{"tensor_parallel_size": 2, "gpu_memory_utilization": 0.9}\'.'
            ),
        )
        args = parser.parse_args()

        try:
            engine_kwargs = json.loads(args.engine_kwargs) if args.engine_kwargs else {}
            if not isinstance(engine_kwargs, dict):
                raise ValueError("engine_kwargs must be a JSON object")
        except Exception as e:
            raise SystemExit(f"Failed to parse --engine_kwargs: {e}") from e

        return cls(
            dataset=args.dataset,
            model_A=args.model_A,
            model_B=args.model_B,
            judge_model=args.judge_model,
            n_instructions=args.n_instructions,
            provide_explanation=args.provide_explanation,
            swap_mode=args.swap_mode,
            ignore_cache=args.ignore_cache,
            use_tqdm=args.use_tqdm,
            truncate_all_input_chars=args.truncate_all_input_chars,
            max_out_tokens_models=args.max_out_tokens_models,
            max_out_tokens_judge=args.max_out_tokens_judge,
            max_model_len=args.max_model_len,
            chat_template=args.chat_template,
            result_folder=args.result_folder,
            engine_kwargs=engine_kwargs,
        )

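A hedged usage sketch of the new CLI surface; the invocation values below are illustrative, and only the flags themselves come from the parser above:

import sys

from judgearena.config import CliArgs

# Hypothetical command line; dataset and model names are examples only.
sys.argv = [
    "judgearena",
    "--dataset", "alpaca-eval",
    "--model_A", "VLLM/meta-llama/Meta-Llama-3-8B-Instruct",
    "--model_B", "VLLM/mistralai/Mistral-7B-Instruct-v0.3",
    "--judge_model", "Together/meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
    "--swap_mode", "both",
    "--max_model_len", "8192",
    "--engine_kwargs", '{"tensor_parallel_size": 2, "gpu_memory_utilization": 0.9}',
]

args = CliArgs.parse_args()
assert args.engine_kwargs == {"tensor_parallel_size": 2, "gpu_memory_utilization": 0.9}
assert args.swap_mode == "both"  # __post_init__ would reject anything else

Note that --engine_kwargs is parsed to a plain dict exactly once here, so downstream code never has to re-parse JSON.
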
judgearena/eval_utils.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
"""Shared evaluation runtime helpers used by entrypoints and benchmark pipelines."""

from __future__ import annotations

from dataclasses import dataclass

import pandas as pd

from judgearena.evaluate import PairScore, annotate_battles
from judgearena.utils import compute_pref_summary


def print_results(results):
    """Print battle results in a readable format."""
    print("\n" + "=" * 60)
    print("🏆 MODEL BATTLE RESULTS 🏆".center(60))
    print(f"📊 Dataset: {results['dataset']}")
    print(
        f"🤖 Competitors: Model A: {results['model_A']} vs Model B: {results['model_B']}"
    )
    print(f"⚖️ Judge: {results['judge_model']}")
    print("📈 Results Summary:")
    print(f" Total Battles: {results['num_battles']}")
    print(f" Win Rate (A): {results['winrate']:.1%}")
    print(f" ✅ Wins: {results['num_wins']}")
    print(f" ❌ Losses: {results['num_losses']}")
    print(f" 🤝 Ties: {results['num_ties']}")
    if results.get("num_missing", 0) > 0:
        print(f" ❓ Missing: {results['num_missing']}")

    per_category = results.get("per_category")
    if per_category:
        print("\nPer-Category Breakdown:")
        print(
            f" {'Category':<14} | {'Win Rate(A)':>11} | {'Wins':>4} | {'Losses':>6} | {'Ties':>4}"
        )
        print(f" {'-' * 14}-+-{'-' * 11}-+-{'-' * 4}-+-{'-' * 6}-+-{'-' * 4}")
        for cat, stats in sorted(per_category.items()):
            print(
                f" {cat:<14} | {stats['winrate']:>11.1%} | "
                f"{stats['num_wins']:>4} | {stats['num_losses']:>6} | {stats['num_ties']:>4}"
            )

    per_turn = results.get("per_turn")
    if per_turn:
        print("\nPer-Turn Breakdown:")
        for turn, stats in sorted(per_turn.items()):
            print(
                f" Turn {turn} Win Rate(A): {stats['winrate']:.1%} "
                f"(W:{stats['num_wins']} L:{stats['num_losses']} T:{stats['num_ties']})"
            )
    print("=" * 60 + "\n")


def _compute_grouped_stats(
    preferences: pd.Series,
    metadata: list[dict[str, object]],
    group_by: str,
) -> dict[object, dict[str, float | int]]:
    grouped: dict[object, list[float]] = {}
    for meta, pref in zip(metadata, preferences, strict=True):
        key = meta.get(group_by)
        if key is None:
            continue
        grouped.setdefault(key, []).append(pref)
    return {key: compute_pref_summary(pd.Series(vals)) for key, vals in grouped.items()}


def _parse_preferences_from_annotations(
    annotations: list,
    score_parser: PairScore,
) -> pd.Series:
    return pd.Series(
        [
            score_parser.parse_model_raw(annotation.judge_completion)
            for annotation in annotations
        ]
    )


@dataclass
class JudgeAnnotationResult:
    annotations: list
    annotations_reversed: list
    metadata_for_annotations: list[dict[str, object]]
    metadata_for_reversed_annotations: list[dict[str, object]]
    preferences: pd.Series
    combined_metadata: list[dict[str, object]]


def _make_judge_annotation(
    *,
    judge_chat_model,
    instructions: list[str],
    completions_A: list[str],
    completions_B: list[str],
    metadata: list[dict[str, object]],
    score_parser: PairScore,
    provide_explanation: bool,
    swap_mode: str,
    truncate_input_chars: int | None,
    use_tqdm: bool,
    system_prompt: str | None = None,
    user_prompt_template: str | None = None,
) -> JudgeAnnotationResult:
    if not instructions:
        raise ValueError("instructions must be non-empty")

    annotations = annotate_battles(
        judge_chat_model=judge_chat_model,
        instructions=instructions,
        completions_A=completions_A,
        completions_B=completions_B,
        provide_explanation=provide_explanation,
        system_prompt=system_prompt,
        user_prompt_template=user_prompt_template,
        truncate_input_chars=truncate_input_chars,
        use_tqdm=use_tqdm,
    )
    preference_parts = [_parse_preferences_from_annotations(annotations, score_parser)]

    annotations_reversed: list = []
    metadata_for_reversed_annotations: list[dict[str, object]] = []
    combined_metadata = list(metadata)

    if swap_mode == "both":
        print("Correction for judge bias towards a certain model position is set.")
        print("Evaluating completions with models reversed.")
        annotations_reversed = annotate_battles(
            judge_chat_model=judge_chat_model,
            instructions=instructions,
            completions_A=completions_B,
            completions_B=completions_A,
            provide_explanation=provide_explanation,
            system_prompt=system_prompt,
            user_prompt_template=user_prompt_template,
            truncate_input_chars=truncate_input_chars,
            use_tqdm=use_tqdm,
        )
        prefs_reversed = _parse_preferences_from_annotations(
            annotations_reversed, score_parser
        )
        preference_parts.append(1 - prefs_reversed)
        metadata_for_reversed_annotations = list(metadata)
        combined_metadata.extend(metadata)

    preferences = pd.concat(preference_parts).reset_index(drop=True)
    return JudgeAnnotationResult(
        annotations=annotations,
        annotations_reversed=annotations_reversed,
        metadata_for_annotations=list(metadata),
        metadata_for_reversed_annotations=metadata_for_reversed_annotations,
        preferences=preferences,
        combined_metadata=combined_metadata,
    )

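A toy sketch of the swap_mode="both" correction implemented by `_make_judge_annotation`, assuming parsed preferences are P(model A preferred) in [0, 1] (the values below are made up):

import pandas as pd

# Forward pass: the judge sees the true A/B order and scores P(A preferred).
prefs_forward = pd.Series([1.0, 0.5, 0.0])
# Reversed pass: completions are swapped, so the judge's score is
# P(B preferred); 1 - score maps it back to P(A preferred), exactly as
# `preference_parts.append(1 - prefs_reversed)` does above.
prefs_reversed = pd.Series([0.0, 0.5, 1.0])

preferences = pd.concat([prefs_forward, 1 - prefs_reversed]).reset_index(drop=True)
print(preferences.mean())  # position-debiased win rate for A: 0.5 here

Averaging both orderings cancels any constant bonus the judge gives to the first-listed answer.
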
judgearena/evaluate.py

Lines changed: 22 additions & 6 deletions
@@ -51,30 +51,46 @@ def get_regexp_match(self, s: str, regex: str, group_index: int = 1):
         return float(m.group(group_index).strip(" "))
 
 
+_COMPLETION_LABEL_SINGLE = "Answer"
+_COMPLETION_LABEL_MULTI_TURN = "Conversation with User"
+_EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement"
+_SCORE_FENCE = "\n```"
+
+
 def load_judge_system_and_user_prompt(
     provide_explanation: bool = True,
+    multi_turn: bool = False,
 ) -> tuple[str, str]:
-    # Prepare judge
-    with open(Path(__file__).parent / "prompts" / "system-prompt.txt") as f:
-        system_prompt = str(f.read())
+    prompts_dir = Path(__file__).parent / "prompts"
+    system_prompt = (prompts_dir / "system-prompt.txt").read_text()
 
     prompt_filename = (
         "prompt-with-explanation.txt" if provide_explanation else "prompt.txt"
     )
-    with open(Path(__file__).parent / "prompts" / prompt_filename) as f:
-        user_prompt_template = str(f.read())
+    user_prompt_template = (prompts_dir / prompt_filename).read_text()
+    user_prompt_template = user_prompt_template.replace(
+        "{completion_label}",
+        _COMPLETION_LABEL_MULTI_TURN if multi_turn else _COMPLETION_LABEL_SINGLE,
+    )
+    user_prompt_template = user_prompt_template.replace(
+        "{explanation_suffix}",
+        _EXPLANATION_SUFFIX if provide_explanation else _SCORE_FENCE,
+    )
 
     return system_prompt, user_prompt_template
 
 
 def resolve_judge_prompts(
     *,
     provide_explanation: bool,
+    multi_turn: bool = False,
     system_prompt: str | None = None,
     user_prompt_template: str | None = None,
 ) -> tuple[str, str]:
     default_system_prompt, default_user_prompt_template = (
-        load_judge_system_and_user_prompt(provide_explanation=provide_explanation)
+        load_judge_system_and_user_prompt(
+            provide_explanation=provide_explanation, multi_turn=multi_turn
+        )
     )
     return (
         system_prompt if system_prompt is not None else default_system_prompt,