Skip to content

Commit ab8ee12

Browse files
committed
fix(lint): run ruff format on all files for CI compliance
1 parent 79e1fbd commit ab8ee12

4 files changed

Lines changed: 40 additions & 38 deletions

File tree

src/alignrl/eval.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ def _resolve_preset(self) -> EvalConfig:
4646
if self.preset is not None:
4747
if self.preset not in BENCHMARK_PRESETS:
4848
raise ValueError(
49-
f"Unknown preset {self.preset!r}. "
50-
f"Available: {', '.join(BENCHMARK_PRESETS)}"
49+
f"Unknown preset {self.preset!r}. Available: {', '.join(BENCHMARK_PRESETS)}"
5150
)
5251
self.tasks = BENCHMARK_PRESETS[self.preset]
5352
else:
@@ -60,7 +59,8 @@ def parse_results(raw: dict[str, Any], model_name: str, stage: str) -> EvalResul
6059
benchmarks: dict[str, dict[str, float]] = {}
6160
for task_name, metrics in raw.get("results", {}).items():
6261
benchmarks[task_name] = {
63-
k: v for k, v in metrics.items()
62+
k: v
63+
for k, v in metrics.items()
6464
if isinstance(v, (int, float)) and not isinstance(v, bool)
6565
}
6666
return EvalResult(model_name=model_name, stage=stage, benchmarks=benchmarks)

tests/test_cli.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ def test_train_sft_stage(self, tmp_path) -> None:
4545
args = argparse.Namespace(config=str(config_path), stage="sft", push=None)
4646

4747
mock_runner = MagicMock()
48-
mock_runner.train.return_value = MagicMock(
49-
output_dir=tmp_path, metrics={"train_loss": 0.5}
50-
)
48+
mock_runner.train.return_value = MagicMock(output_dir=tmp_path, metrics={"train_loss": 0.5})
5149

52-
with patch("alignrl.sft.SFTRunner", return_value=mock_runner) as mock_cls, \
53-
patch("alignrl.sft.SFTConfig") as mock_cfg_cls:
50+
with (
51+
patch("alignrl.sft.SFTRunner", return_value=mock_runner) as mock_cls,
52+
patch("alignrl.sft.SFTConfig") as mock_cfg_cls,
53+
):
5454
mock_cfg_cls.from_yaml.return_value = MagicMock()
5555
cmd_train(args)
5656
mock_cls.assert_called_once()
@@ -62,12 +62,12 @@ def test_train_grpo_stage(self, tmp_path) -> None:
6262
args = argparse.Namespace(config=str(config_path), stage="grpo", push=None)
6363

6464
mock_runner = MagicMock()
65-
mock_runner.train.return_value = MagicMock(
66-
output_dir=tmp_path, metrics={"train_loss": 0.3}
67-
)
65+
mock_runner.train.return_value = MagicMock(output_dir=tmp_path, metrics={"train_loss": 0.3})
6866

69-
with patch("alignrl.grpo.GRPORunner", return_value=mock_runner), \
70-
patch("alignrl.grpo.GRPOConfig") as mock_cfg_cls:
67+
with (
68+
patch("alignrl.grpo.GRPORunner", return_value=mock_runner),
69+
patch("alignrl.grpo.GRPOConfig") as mock_cfg_cls,
70+
):
7171
mock_cfg_cls.from_yaml.return_value = MagicMock()
7272
cmd_train(args)
7373
mock_runner.train.assert_called_once()
@@ -78,12 +78,12 @@ def test_train_dpo_stage(self, tmp_path) -> None:
7878
args = argparse.Namespace(config=str(config_path), stage="dpo", push=None)
7979

8080
mock_runner = MagicMock()
81-
mock_runner.train.return_value = MagicMock(
82-
output_dir=tmp_path, metrics={"train_loss": 0.2}
83-
)
81+
mock_runner.train.return_value = MagicMock(output_dir=tmp_path, metrics={"train_loss": 0.2})
8482

85-
with patch("alignrl.dpo.DPORunner", return_value=mock_runner), \
86-
patch("alignrl.dpo.DPOConfig") as mock_cfg_cls:
83+
with (
84+
patch("alignrl.dpo.DPORunner", return_value=mock_runner),
85+
patch("alignrl.dpo.DPOConfig") as mock_cfg_cls,
86+
):
8787
mock_cfg_cls.from_yaml.return_value = MagicMock()
8888
cmd_train(args)
8989
mock_runner.train.assert_called_once()
@@ -116,8 +116,10 @@ def test_eval_creates_output(self, tmp_path) -> None:
116116
wandb=False,
117117
)
118118

119-
with patch("alignrl.eval.EvalRunner", return_value=mock_runner), \
120-
patch("alignrl.eval.EvalConfig"):
119+
with (
120+
patch("alignrl.eval.EvalRunner", return_value=mock_runner),
121+
patch("alignrl.eval.EvalConfig"),
122+
):
121123
cmd_eval(args)
122124
mock_runner.evaluate.assert_called_once_with(stage="base")
123125
assert (tmp_path / "results" / "eval_base.json").exists()

tests/test_demo.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ def test_creates_gradio_app(self) -> None:
1515
mock_server = MagicMock()
1616
mock_server.generate.return_value = "test output"
1717

18-
with patch.dict("sys.modules", {"gradio": mock_gr}), \
19-
patch("alignrl.demo.ModelServer", return_value=mock_server), \
20-
patch("alignrl.demo.InferenceConfig"):
18+
with (
19+
patch.dict("sys.modules", {"gradio": mock_gr}),
20+
patch("alignrl.demo.ModelServer", return_value=mock_server),
21+
patch("alignrl.demo.InferenceConfig"),
22+
):
2123
mock_server.load = MagicMock()
2224
app = create_demo(stages={"base": None}, model_name="test-model")
2325
assert app is not None
@@ -31,9 +33,11 @@ def test_multiple_stages(self) -> None:
3133

3234
mock_server = MagicMock()
3335

34-
with patch.dict("sys.modules", {"gradio": mock_gr}), \
35-
patch("alignrl.demo.ModelServer", return_value=mock_server), \
36-
patch("alignrl.demo.InferenceConfig"):
36+
with (
37+
patch.dict("sys.modules", {"gradio": mock_gr}),
38+
patch("alignrl.demo.ModelServer", return_value=mock_server),
39+
patch("alignrl.demo.InferenceConfig"),
40+
):
3741
create_demo(
3842
stages={"base": None, "sft": "./sft", "grpo": "./grpo"},
3943
model_name="test-model",
@@ -65,9 +69,11 @@ def capture_click(fn, **kwargs):
6569
mock_button.click.side_effect = capture_click
6670
mock_gr.Button.return_value = mock_button
6771

68-
with patch.dict("sys.modules", {"gradio": mock_gr}), \
69-
patch("alignrl.demo.ModelServer", return_value=mock_server), \
70-
patch("alignrl.demo.InferenceConfig"):
72+
with (
73+
patch.dict("sys.modules", {"gradio": mock_gr}),
74+
patch("alignrl.demo.ModelServer", return_value=mock_server),
75+
patch("alignrl.demo.InferenceConfig"),
76+
):
7177
create_demo(stages={"base": None, "sft": "./sft"}, model_name="test")
7278

7379
# Now call the captured respond_all function

tests/test_eval.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,7 @@ def test_save_results(self, tmp_path: Path) -> None:
125125
assert comparison["gsm8k"]["sft"]["exact_match"] == 0.45
126126

127127
def test_save_results_creates_dir(self, tmp_path: Path) -> None:
128-
result = EvalResult(
129-
model_name="qwen", stage="base", benchmarks={"arc": {"acc": 0.5}}
130-
)
128+
result = EvalResult(model_name="qwen", stage="base", benchmarks={"arc": {"acc": 0.5}})
131129
cfg = EvalConfig()
132130
runner = EvalRunner(cfg)
133131
nested = tmp_path / "deep" / "nested" / "dir"
@@ -155,9 +153,7 @@ def test_evaluate_all_stages_restores_config(self) -> None:
155153
cfg = EvalConfig(adapter_path="original")
156154
runner = EvalRunner(cfg)
157155

158-
mock_result = EvalResult(
159-
model_name="qwen", stage="base", benchmarks={}
160-
)
156+
mock_result = EvalResult(model_name="qwen", stage="base", benchmarks={})
161157

162158
with patch.object(runner, "evaluate", return_value=mock_result):
163159
runner.evaluate_all_stages({"base": None, "sft": "./adapter"})
@@ -167,9 +163,7 @@ def test_evaluate_builds_model_args(self) -> None:
167163
cfg = EvalConfig(model_name="test-model", load_in_4bit=True, adapter_path="./adapter")
168164
runner = EvalRunner(cfg)
169165

170-
mock_raw = {
171-
"results": {"gsm8k": {"exact_match,strict-match": 0.50}}
172-
}
166+
mock_raw = {"results": {"gsm8k": {"exact_match,strict-match": 0.50}}}
173167

174168
mock_lm = MagicMock()
175169
mock_lm.simple_evaluate.return_value = mock_raw

0 commit comments

Comments
 (0)