-
Notifications
You must be signed in to change notification settings - Fork 1.4k
[New Feature] MOPD #9035
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[New Feature] MOPD #9035
Changes from all commits
e9e0f38
a542b7a
54822f2
664a4e1
dc4be10
a1fcda7
8923a54
0bd9b55
7c70a3a
d8e95c9
5667586
a227202
4f75226
e1e58af
b53a0f3
1d2cd01
479ddf7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| NPROC_PER_NODE=7 \ | ||
| PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ | ||
| CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 \ | ||
| swift rlhf \ | ||
| --rlhf_type gkd \ | ||
| --model Qwen/Qwen3-8B-Base \ | ||
| --teacher_model_group Qwen/Qwen3-32B AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2 \ | ||
| --use_mopd true \ | ||
| --tuner_type full \ | ||
| --dataset open-thoughts/OpenThoughts3-1.2M#10000 \ | ||
| --seq_kd false \ | ||
| --lmbda 1 \ | ||
| --beta 1 \ | ||
| --torch_dtype bfloat16 \ | ||
| --num_train_epochs 1 \ | ||
| --per_device_train_batch_size 1 \ | ||
| --learning_rate 1e-5 \ | ||
| --gradient_accumulation_steps 1 \ | ||
| --save_steps 1000 \ | ||
| --save_total_limit 2 \ | ||
| --logging_steps 1 \ | ||
| --max_length 16000 \ | ||
| --max_completion_length 8192 \ | ||
| --output_dir output \ | ||
| --warmup_ratio 0.05 \ | ||
| --save_only_model true \ | ||
| --dataloader_num_workers 64 \ | ||
| --dataset_num_proc 4 \ | ||
| --deepspeed zero3 \ | ||
| --teacher_deepspeed zero3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,9 @@ class TeacherModelArguments: | |
| remotely. When this is set, `teacher_model` is not required. Defaults to None. | ||
| """ | ||
| teacher_model: Optional[str] = None | ||
| teacher_model_group: List[str] = field(default_factory=list) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider changing `teacher_model` to `Optional[List[str]]` (as is done for `reward_model`) instead of introducing an additional parameter.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No need for an extra `use_mopd` parameter; whether MOPD is used can be determined from the number of teacher models. |
||
| # todo 还需要增加mopd_config | ||
| use_mopd: bool = False | ||
| teacher_adapters: List[str] = field(default_factory=list) | ||
| teacher_model_type: Optional[str] = field( | ||
| default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'}) | ||
|
|
@@ -73,15 +76,15 @@ class TeacherModelArguments: | |
| default=None, | ||
| metadata={ | ||
| 'help': | ||
| 'DeepSpeed configuration for teacher model. ' | ||
| 'Can be a path to a json file or one of: zero0, zero1, zero2, zero3, zero2_offload, zero3_offload' | ||
| 'DeepSpeed configuration for teacher model. ' | ||
| 'Can be a path to a json file or one of: zero0, zero1, zero2, zero3, zero2_offload, zero3_offload' | ||
| }) | ||
| teacher_model_server: Optional[str] = field( | ||
| default=None, | ||
| metadata={ | ||
| 'help': | ||
| 'URL of the teacher model server (e.g., http://localhost:8000). ' | ||
| 'When set, teacher logprobs are fetched via API instead of loading a local model.' | ||
| 'URL of the teacher model server (e.g., http://localhost:8000). ' | ||
| 'When set, teacher logprobs are fetched via API instead of loading a local model.' | ||
| }) | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -132,6 +132,24 @@ def _prepare_model_tokenizer(self): | |
| model, _ = result | ||
| setattr(self, f'{key}_model', model) | ||
|
|
||
| # Handle teacher_model_group for GKD | ||
| self.teacher_model_group_models = None | ||
| if args.rlhf_type == 'gkd' and hasattr(args, 'teacher_model_group') and args.teacher_model_group: | ||
| logger.info(f'Loading teacher_model_group with {len(args.teacher_model_group)} models') | ||
| self.teacher_model_group_models = [] | ||
| for idx, teacher_model_path in enumerate(args.teacher_model_group): | ||
| logger.info(f'Loading teacher model group [{idx}]: {teacher_model_path}') | ||
| # Use teacher_model_type and teacher_model_revision if available, otherwise infer | ||
| model_type = getattr(args, 'teacher_model_type', None) | ||
| model_revision = getattr(args, 'teacher_model_revision', None) | ||
|
|
||
| result = self._prepare_single_model_for_teacher_group(teacher_model_path, model_type, model_revision) | ||
| if result is not None: | ||
| model, _ = result | ||
| self.teacher_model_group_models.append(model) | ||
| logger.info(f'Successfully loaded teacher model group [{idx}]: {model}') | ||
| logger.info(f'Total teacher_model_group_models loaded: {len(self.teacher_model_group_models)}') | ||
|
Comment on lines
+136
to
+151
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similarly, handle this the same way `reward_model` is handled, without introducing additional logic. |
||
|
|
||
| # Handle reward model(s) | ||
| self.reward_model = None | ||
| if hasattr(args, 'reward_model') and args.reward_model is not None: | ||
|
|
@@ -166,6 +184,44 @@ def _prepare_model_tokenizer(self): | |
|
|
||
| super()._prepare_model_tokenizer() | ||
|
|
||
| def _prepare_single_model_for_teacher_group(self, model_id_or_path, model_type, model_revision): | ||
| """Prepare a single model for teacher_model_group.""" | ||
| args = self.args | ||
|
|
||
| if model_type is None: | ||
| model_info, _ = get_model_info_meta(model_id_or_path) | ||
| model_type = model_info.model_type | ||
|
|
||
| model_dir = safe_snapshot_download( | ||
| model_id_or_path=model_id_or_path, | ||
| revision=model_revision, | ||
| download_model=False, | ||
| use_hf=args.use_hf, | ||
| hub_token=args.hub_token, | ||
| ) | ||
| task_type, num_labels = self._get_model_task_type(model_dir) | ||
|
|
||
| context = nullcontext() | ||
| if args.teacher_deepspeed: | ||
| if args.teacher_deepspeed.get('zero_optimization', {}).get('stage') != 3: | ||
| context = disable_deepspeed_zero3() | ||
| with context: | ||
| model, processor = args.get_model_processor( | ||
| model=model_id_or_path, | ||
| model_type=model_type, | ||
| revision=model_revision, | ||
| task_type=task_type, | ||
| num_labels=num_labels) | ||
|
|
||
| # For teacher models, set to eval mode and disable gradients | ||
| if self.args.sequence_parallel_size > 1: | ||
| sequence_parallel.prepare( | ||
| self.args.sequence_parallel_size, model, processor, padding_free=args.padding_free) | ||
| model.requires_grad_(False).eval() | ||
|
|
||
| HfConfigFactory.set_config_attr(model.config, 'use_cache', False) | ||
| return model, processor | ||
|
Comment on lines
+187
to
+223
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
|
|
||
| @classmethod | ||
| def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_type=None): | ||
| model = super().prepare_model(args, model, template=template, train_dataset=train_dataset, task_type=task_type) | ||
|
|
@@ -238,7 +294,16 @@ def _get_trainer_kwargs(self): | |
| trainer_kwargs['gkd_logits_topk'] = self.args.gkd_logits_topk | ||
| if self.args.teacher_model_server: | ||
| trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server | ||
| # Pass pre-loaded teacher_model_group_models if available, otherwise pass the string list | ||
| if hasattr(self, 'teacher_model_group_models') and self.teacher_model_group_models: | ||
| trainer_kwargs['teacher_model_group_models'] = self.teacher_model_group_models | ||
| else: | ||
| trainer_kwargs['teacher_model_group'] = self.args.teacher_model_group | ||
| trainer_kwargs['teacher_use_disable_adapter'] = getattr(self.args, '_teacher_use_disable_adapter', False) | ||
| if self.args.use_mopd: | ||
| trainer_kwargs['use_mopd'] = self.args.use_mopd | ||
| # todo | ||
| # trainer_kwargs['mopd_config'] = self.args.mopd_config | ||
|
Comment on lines
+297
to
+306
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
| return trainer_kwargs | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use_mopd标志在GKDTrainer中被引用,但未在参数定义中声明。应在此处添加以避免AttributeError。此外,建议更新TeacherModelArguments的 docstring 以包含teacher_model_group和use_mopd的说明。