Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions examples/train/rlhf/mopd/mopd.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Example: GKD training (`--rlhf_type gkd`) with `--use_mopd true` and a group of teacher models.
# `--model` is Qwen/Qwen3-8B-Base with `--tuner_type full`; the two model ids that follow
# `--teacher_model_group` are passed together as the teacher group.
# NPROC_PER_NODE=7 is paired with the seven GPU ids (0-6) listed in CUDA_VISIBLE_DEVICES.
# Both `--deepspeed` and `--teacher_deepspeed` select the `zero3` preset.
# NOTE: this is one command joined by trailing backslashes, so comments cannot be placed
# between the argument lines without breaking the continuation.
NPROC_PER_NODE=7 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 \
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen3-8B-Base \
--teacher_model_group Qwen/Qwen3-32B AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2 \
--use_mopd true \
--tuner_type full \
--dataset open-thoughts/OpenThoughts3-1.2M#10000 \
--seq_kd false \
--lmbda 1 \
--beta 1 \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 1 \
--save_steps 1000 \
--save_total_limit 2 \
--logging_steps 1 \
--max_length 16000 \
--max_completion_length 8192 \
--output_dir output \
--warmup_ratio 0.05 \
--save_only_model true \
--dataloader_num_workers 64 \
--dataset_num_proc 4 \
--deepspeed zero3 \
--teacher_deepspeed zero3
11 changes: 7 additions & 4 deletions swift/arguments/rlhf_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class TeacherModelArguments:
remotely. When this is set, `teacher_model` is not required. Defaults to None.
"""
teacher_model: Optional[str] = None
teacher_model_group: List[str] = field(default_factory=list)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

use_mopd 标志在 GKDTrainer 中被引用,但未在参数定义中声明。应在此处添加以避免 AttributeError。此外,建议更新 TeacherModelArguments 的 docstring 以包含 teacher_model_group 和 use_mopd 的说明。

Suggested change
teacher_model_group: List[str] = field(default_factory=list)
teacher_model_group: List[str] = field(default_factory=list)
use_mopd: bool = False

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider changing teacher_model to Optional[List[str]] (similar to reward_model) to avoid introducing additional parameters

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no need for an extra use_mopd parameter; MOPD can be inferred from the number of teacher models.

# TODO: mopd_config still needs to be added
use_mopd: bool = False
teacher_adapters: List[str] = field(default_factory=list)
teacher_model_type: Optional[str] = field(
default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
Expand All @@ -73,15 +76,15 @@ class TeacherModelArguments:
default=None,
metadata={
'help':
'DeepSpeed configuration for teacher model. '
'Can be a path to a json file or one of: zero0, zero1, zero2, zero3, zero2_offload, zero3_offload'
'DeepSpeed configuration for teacher model. '
'Can be a path to a json file or one of: zero0, zero1, zero2, zero3, zero2_offload, zero3_offload'
})
teacher_model_server: Optional[str] = field(
default=None,
metadata={
'help':
'URL of the teacher model server (e.g., http://localhost:8000). '
'When set, teacher logprobs are fetched via API instead of loading a local model.'
'URL of the teacher model server (e.g., http://localhost:8000). '
'When set, teacher logprobs are fetched via API instead of loading a local model.'
})


Expand Down
65 changes: 65 additions & 0 deletions swift/pipelines/train/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,24 @@ def _prepare_model_tokenizer(self):
model, _ = result
setattr(self, f'{key}_model', model)

# Handle teacher_model_group for GKD
self.teacher_model_group_models = None
if args.rlhf_type == 'gkd' and hasattr(args, 'teacher_model_group') and args.teacher_model_group:
logger.info(f'Loading teacher_model_group with {len(args.teacher_model_group)} models')
self.teacher_model_group_models = []
for idx, teacher_model_path in enumerate(args.teacher_model_group):
logger.info(f'Loading teacher model group [{idx}]: {teacher_model_path}')
# Use teacher_model_type and teacher_model_revision if available, otherwise infer
model_type = getattr(args, 'teacher_model_type', None)
model_revision = getattr(args, 'teacher_model_revision', None)

result = self._prepare_single_model_for_teacher_group(teacher_model_path, model_type, model_revision)
if result is not None:
model, _ = result
self.teacher_model_group_models.append(model)
logger.info(f'Successfully loaded teacher model group [{idx}]: {model}')
logger.info(f'Total teacher_model_group_models loaded: {len(self.teacher_model_group_models)}')
Comment on lines +136 to +151
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, handle this the same way as reward_model, without introducing additional logic.


# Handle reward model(s)
self.reward_model = None
if hasattr(args, 'reward_model') and args.reward_model is not None:
Expand Down Expand Up @@ -166,6 +184,44 @@ def _prepare_model_tokenizer(self):

super()._prepare_model_tokenizer()

def _prepare_single_model_for_teacher_group(self, model_id_or_path, model_type, model_revision):
    """Load one frozen teacher model belonging to ``teacher_model_group``.

    Args:
        model_id_or_path: Model id or path of the teacher to load.
        model_type: Explicit model type. When ``None`` it is taken from the
            info object returned by ``get_model_info_meta``.
        model_revision: Revision forwarded to both the snapshot download and
            the model loader.

    Returns:
        A ``(model, processor)`` tuple. The model has gradients disabled, is
        in eval mode, and has ``use_cache`` set to ``False`` on its config.
    """
    args = self.args

    # No explicit type given: fall back to the one reported for this model.
    if model_type is None:
        teacher_info, _ = get_model_info_meta(model_id_or_path)
        model_type = teacher_info.model_type

    # Called with download_model=False; the returned directory is only used
    # to determine the task type and the number of labels.
    snapshot_dir = safe_snapshot_download(
        model_id_or_path=model_id_or_path,
        revision=model_revision,
        download_model=False,
        use_hf=args.use_hf,
        hub_token=args.hub_token,
    )
    task_type, num_labels = self._get_model_task_type(snapshot_dir)

    # A teacher DeepSpeed config whose ZeRO stage is not 3 loads the model
    # under ``disable_deepspeed_zero3()``; otherwise no special context is used.
    teacher_ds = args.teacher_deepspeed
    zero3_must_be_disabled = bool(teacher_ds) and teacher_ds.get('zero_optimization', {}).get('stage') != 3
    load_context = disable_deepspeed_zero3() if zero3_must_be_disabled else nullcontext()

    with load_context:
        model, processor = args.get_model_processor(
            model=model_id_or_path,
            model_type=model_type,
            revision=model_revision,
            task_type=task_type,
            num_labels=num_labels)

    sp_size = args.sequence_parallel_size
    if sp_size > 1:
        sequence_parallel.prepare(sp_size, model, processor, padding_free=args.padding_free)

    # Teachers are never trained: freeze the parameters and switch to eval mode.
    model.requires_grad_(False).eval()
    HfConfigFactory.set_config_attr(model.config, 'use_cache', False)

    return model, processor
Comment on lines +187 to +223
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above


@classmethod
def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_type=None):
model = super().prepare_model(args, model, template=template, train_dataset=train_dataset, task_type=task_type)
Expand Down Expand Up @@ -238,7 +294,16 @@ def _get_trainer_kwargs(self):
trainer_kwargs['gkd_logits_topk'] = self.args.gkd_logits_topk
if self.args.teacher_model_server:
trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server
# Pass pre-loaded teacher_model_group_models if available, otherwise pass the string list
if hasattr(self, 'teacher_model_group_models') and self.teacher_model_group_models:
trainer_kwargs['teacher_model_group_models'] = self.teacher_model_group_models
else:
trainer_kwargs['teacher_model_group'] = self.args.teacher_model_group
trainer_kwargs['teacher_use_disable_adapter'] = getattr(self.args, '_teacher_use_disable_adapter', False)
if self.args.use_mopd:
trainer_kwargs['use_mopd'] = self.args.use_mopd
# todo
# trainer_kwargs['mopd_config'] = self.args.mopd_config
Comment on lines +297 to +306
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

return trainer_kwargs


Expand Down
Loading