@@ -45,12 +45,12 @@ def test_train_sft_stage(self, tmp_path) -> None:
4545 args = argparse .Namespace (config = str (config_path ), stage = "sft" , push = None )
4646
4747 mock_runner = MagicMock ()
48- mock_runner .train .return_value = MagicMock (
49- output_dir = tmp_path , metrics = {"train_loss" : 0.5 }
50- )
48+ mock_runner .train .return_value = MagicMock (output_dir = tmp_path , metrics = {"train_loss" : 0.5 })
5149
52- with patch ("alignrl.sft.SFTRunner" , return_value = mock_runner ) as mock_cls , \
53- patch ("alignrl.sft.SFTConfig" ) as mock_cfg_cls :
50+ with (
51+ patch ("alignrl.sft.SFTRunner" , return_value = mock_runner ) as mock_cls ,
52+ patch ("alignrl.sft.SFTConfig" ) as mock_cfg_cls ,
53+ ):
5454 mock_cfg_cls .from_yaml .return_value = MagicMock ()
5555 cmd_train (args )
5656 mock_cls .assert_called_once ()
@@ -62,12 +62,12 @@ def test_train_grpo_stage(self, tmp_path) -> None:
6262 args = argparse .Namespace (config = str (config_path ), stage = "grpo" , push = None )
6363
6464 mock_runner = MagicMock ()
65- mock_runner .train .return_value = MagicMock (
66- output_dir = tmp_path , metrics = {"train_loss" : 0.3 }
67- )
65+ mock_runner .train .return_value = MagicMock (output_dir = tmp_path , metrics = {"train_loss" : 0.3 })
6866
69- with patch ("alignrl.grpo.GRPORunner" , return_value = mock_runner ), \
70- patch ("alignrl.grpo.GRPOConfig" ) as mock_cfg_cls :
67+ with (
68+ patch ("alignrl.grpo.GRPORunner" , return_value = mock_runner ),
69+ patch ("alignrl.grpo.GRPOConfig" ) as mock_cfg_cls ,
70+ ):
7171 mock_cfg_cls .from_yaml .return_value = MagicMock ()
7272 cmd_train (args )
7373 mock_runner .train .assert_called_once ()
@@ -78,12 +78,12 @@ def test_train_dpo_stage(self, tmp_path) -> None:
7878 args = argparse .Namespace (config = str (config_path ), stage = "dpo" , push = None )
7979
8080 mock_runner = MagicMock ()
81- mock_runner .train .return_value = MagicMock (
82- output_dir = tmp_path , metrics = {"train_loss" : 0.2 }
83- )
81+ mock_runner .train .return_value = MagicMock (output_dir = tmp_path , metrics = {"train_loss" : 0.2 })
8482
85- with patch ("alignrl.dpo.DPORunner" , return_value = mock_runner ), \
86- patch ("alignrl.dpo.DPOConfig" ) as mock_cfg_cls :
83+ with (
84+ patch ("alignrl.dpo.DPORunner" , return_value = mock_runner ),
85+ patch ("alignrl.dpo.DPOConfig" ) as mock_cfg_cls ,
86+ ):
8787 mock_cfg_cls .from_yaml .return_value = MagicMock ()
8888 cmd_train (args )
8989 mock_runner .train .assert_called_once ()
@@ -116,8 +116,10 @@ def test_eval_creates_output(self, tmp_path) -> None:
116116 wandb = False ,
117117 )
118118
119- with patch ("alignrl.eval.EvalRunner" , return_value = mock_runner ), \
120- patch ("alignrl.eval.EvalConfig" ):
119+ with (
120+ patch ("alignrl.eval.EvalRunner" , return_value = mock_runner ),
121+ patch ("alignrl.eval.EvalConfig" ),
122+ ):
121123 cmd_eval (args )
122124 mock_runner .evaluate .assert_called_once_with (stage = "base" )
123125 assert (tmp_path / "results" / "eval_base.json" ).exists ()
0 commit comments