From 54822f21b81945761800d25847c877c4a9c115e0 Mon Sep 17 00:00:00 2001 From: Yao Chuming <1416004356@qq.com> Date: Wed, 8 Apr 2026 09:44:44 +0800 Subject: [PATCH 01/13] [MOPD] init --- swift/arguments/rlhf_args.py | 1 + swift/rlhf_trainers/gkd_trainer.py | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/swift/arguments/rlhf_args.py b/swift/arguments/rlhf_args.py index f7e6288806..5b08153d21 100644 --- a/swift/arguments/rlhf_args.py +++ b/swift/arguments/rlhf_args.py @@ -65,6 +65,7 @@ class TeacherModelArguments: remotely. When this is set, `teacher_model` is not required. Defaults to None. """ teacher_model: Optional[str] = None + teacher_model_group: List[str] = field(default_factory=list) teacher_adapters: List[str] = field(default_factory=list) teacher_model_type: Optional[str] = field( default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'}) diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index b6ae231c69..d077cbccc0 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -298,6 +298,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N self.prepare_logits_to_keep(inputs) model_inputs['logits_to_keep'] = inputs['logits_to_keep'] + teacher_model = self.choose_teacher_model() if self.use_liger_gkd_loss: # Liger fused JSD loss for memory efficiency # Get base models (exclude lm_head to save memory) @@ -307,7 +308,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N base_student = getattr(unwrapped_student, getattr(unwrapped_student, 'base_model_prefix', 'model'), unwrapped_student) - unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + unwrapped_teacher = self.accelerator.unwrap_model(teacher_model) base_teacher = getattr(unwrapped_teacher, getattr(unwrapped_teacher, 'base_model_prefix', 'model'), unwrapped_teacher) @@ -316,7 +317,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() with load_context: - with torch.no_grad(), disable_gradient_checkpointing(self.teacher_model, + with torch.no_grad(), disable_gradient_checkpointing(teacher_model, self.args.gradient_checkpointing_kwargs): teacher_outputs = base_teacher(**model_inputs, use_cache=False) @@ -415,7 +416,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss = loss + self.args.sft_alpha * outputs_student.loss # Separate teacher model provided else: - assert self.teacher_model is not None + assert teacher_model is not None if self.args.sft_alpha > 0: model_inputs['labels'] = inputs['labels'] outputs_student = model(**model_inputs) @@ -426,9 +427,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N } load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() - with torch.no_grad(), load_context, disable_gradient_checkpointing(self.teacher_model, + with torch.no_grad(), load_context, disable_gradient_checkpointing(teacher_model, self.args.gradient_checkpointing_kwargs): - outputs_teacher = self.teacher_model(**t_fwd) + outputs_teacher = teacher_model(**t_fwd) opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None teacher_out = TeacherOutput(full_logits=outputs_teacher.logits, opsd_teacher_labels=opsd_labels) @@ -443,6 +444,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N else: return loss + def choose_teacher_model(self): + if not self.args.use_mopd: + return self.teacher_model + #todo 使用mopd时从教师模型组选择最佳模型 + def _prepare_batch_inputs(self, inputs: list, encode_prompt_only: bool = False) -> Dict[str, torch.Tensor]: """Prepare batch inputs for training. @@ -788,6 +794,7 @@ def generalized_jsd_loss( t_log_probs = F.log_softmax(t_chunk, dim=-1) del s_chunk, t_chunk + #todo 使用mopd的计算函数,增加教师模型权重 if beta == 0: jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True) elif beta == 1: From 664a4e1513984ae4ed7dc435f85ed343a5f64877 Mon Sep 17 00:00:00 2001 From: Yao Chuming <1416004356@qq.com> Date: Tue, 12 May 2026 10:46:45 +0800 Subject: [PATCH 02/13] [MOPD] support mult teacher model --- swift/pipelines/train/rlhf.py | 61 +++ swift/rlhf_trainers/GOLDLossAdapter.py | 531 +++++++++++++++++++++++ swift/rlhf_trainers/gkd_trainer.py | 562 ++++++++++++++++++------- 3 files changed, 1010 insertions(+), 144 deletions(-) create mode 100644 swift/rlhf_trainers/GOLDLossAdapter.py diff --git a/swift/pipelines/train/rlhf.py b/swift/pipelines/train/rlhf.py index fd9c16b2ea..6935e34f31 100644 --- a/swift/pipelines/train/rlhf.py +++ b/swift/pipelines/train/rlhf.py @@ -132,6 +132,24 @@ def _prepare_model_tokenizer(self): model, _ = result setattr(self, f'{key}_model', model) + # Handle teacher_model_group for GKD + self.teacher_model_group_models = None + if args.rlhf_type == 'gkd' and hasattr(args, 'teacher_model_group') and args.teacher_model_group: + logger.info(f'Loading teacher_model_group with {len(args.teacher_model_group)} models') + self.teacher_model_group_models = [] + for idx, teacher_model_path in enumerate(args.teacher_model_group): + logger.info(f'Loading teacher model group [{idx}]: {teacher_model_path}') + # Use teacher_model_type and teacher_model_revision if available, otherwise infer + model_type = getattr(args, 'teacher_model_type', None) + model_revision = getattr(args, 'teacher_model_revision', None) + + result = self._prepare_single_model_for_teacher_group(teacher_model_path, model_type, model_revision) + if result is not None: + model, _ = result + self.teacher_model_group_models.append(model) + logger.info(f'Successfully loaded teacher model group [{idx}]: {model}') + logger.info(f'Total teacher_model_group_models loaded: {len(self.teacher_model_group_models)}') + # Handle reward model(s) self.reward_model = None if hasattr(args, 'reward_model') and args.reward_model is not None: @@ -166,6 +184,44 @@ def _prepare_model_tokenizer(self): super()._prepare_model_tokenizer() + def _prepare_single_model_for_teacher_group(self, model_id_or_path, model_type, model_revision): + """Prepare a single model for teacher_model_group.""" + args = self.args + + if model_type is None: + model_info, _ = get_model_info_meta(model_id_or_path) + model_type = model_info.model_type + + model_dir = safe_snapshot_download( + model_id_or_path=model_id_or_path, + revision=model_revision, + download_model=False, + use_hf=args.use_hf, + hub_token=args.hub_token, + ) + task_type, num_labels = self._get_model_task_type(model_dir) + + context = nullcontext() + if args.teacher_deepspeed: + if args.teacher_deepspeed.get('zero_optimization', {}).get('stage') != 3: + context = disable_deepspeed_zero3() + with context: + model, processor = args.get_model_processor( + model=model_id_or_path, + model_type=model_type, + revision=model_revision, + task_type=task_type, + num_labels=num_labels) + + # For teacher models, set to eval mode and disable gradients + if self.args.sequence_parallel_size > 1: + sequence_parallel.prepare( + self.args.sequence_parallel_size, model, processor, padding_free=args.padding_free) + model.requires_grad_(False).eval() + + HfConfigFactory.set_config_attr(model.config, 'use_cache', False) + return model, processor + @classmethod def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_type=None): model = super().prepare_model(args, model, template=template, train_dataset=train_dataset, task_type=task_type) @@ -238,6 +294,11 @@ def _get_trainer_kwargs(self): trainer_kwargs['gkd_logits_topk'] = self.args.gkd_logits_topk if self.args.teacher_model_server: trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server + # Pass pre-loaded teacher_model_group_models if available, otherwise pass the string list + if hasattr(self, 'teacher_model_group_models') and self.teacher_model_group_models: + trainer_kwargs['teacher_model_group_models'] = self.teacher_model_group_models + else: + trainer_kwargs['teacher_model_group'] = self.args.teacher_model_group trainer_kwargs['teacher_use_disable_adapter'] = getattr(self.args, '_teacher_use_disable_adapter', False) return trainer_kwargs diff --git a/swift/rlhf_trainers/GOLDLossAdapter.py b/swift/rlhf_trainers/GOLDLossAdapter.py new file mode 100644 index 0000000000..37ec431b8f --- /dev/null +++ b/swift/rlhf_trainers/GOLDLossAdapter.py @@ -0,0 +1,531 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple, List +from transformers import PreTrainedTokenizerBase + + +class GOLDLossAdapter(nn.Module): + """ + - GOLD (General Online Logit Distillation) 损失函数适配器 + 支持: + 1. ULD损失 (Universal Logit Distillation) + 2. 扩展ULD (跨tokenizer对齐) + 3. 混合损失 (Hybrid ULD + JSD) + + 使用示例: + adapter = GOLDLossAdapter( + config={ + 'use_uld_loss': True, + 'use_extended_uld': True, + 'uld_use_hybrid_loss': False, + 'uld_crossentropy_weight': 0.0, + 'uld_distillation_weight': 1.0, + 'uld_student_temperature': 1.0, + 'uld_teacher_temperature': 1.0, + }, + student_tokenizer=student_tok, + teacher_tokenizer=teacher_tok, + ) + + loss = adapter( + student_logits=student_outputs.logits, + teacher_logits=teacher_outputs.logits, + student_labels=student_labels, + teacher_labels=teacher_labels, + student_input_ids=student_input_ids, + teacher_input_ids=teacher_input_ids, + ) + """ + + def __init__( + self, + config: dict, + student_tokenizer: Optional[PreTrainedTokenizerBase] = None, + teacher_tokenizer: Optional[PreTrainedTokenizerBase] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + self.device = device + + # 基础配置 + self.use_uld_loss = config.get('use_uld_loss', True) # 是否开启通用蒸馏 + self.crossentropy_weight = config.get('uld_crossentropy_weight', 0.0) + self.distillation_weight = config.get('uld_distillation_weight', 1.0) + self.student_temperature = config.get('uld_student_temperature', 0.9) + self.teacher_temperature = config.get('uld_teacher_temperature', 0.9) + self.skip_student_eos = config.get('uld_skip_student_eos', True) + self.skip_teacher_eos = config.get('uld_skip_teacher_eos', True) + self.use_extended_uld = config.get('use_extended_uld', True) + self.ignore_index = -100 + + # Tokenizers + self.student_tokenizer = student_tokenizer + self.teacher_tokenizer = teacher_tokenizer + + # Hybrid ULD配置 + self.use_hybrid_loss = config.get('uld_use_hybrid_loss', True) # 是否对完全匹配的词汇进行匹配,开启提高稳定性 + self.hybrid_matched_weight = config.get('uld_hybrid_matched_weight', None) + self.hybrid_unmatched_weight = config.get('uld_hybrid_unmatched_weight', None) + self.beta = config.get('beta', 1.0) + + # 初始化词汇映射(用于hybrid loss) + self._vocab_mapping = None + self._teacher_matched_ids = None + self._student_matched_ids = None + self.mapping_tensor = None + + if self.use_hybrid_loss and student_tokenizer and teacher_tokenizer: + self._initialize_vocabulary_mapping() + + # 用于logging + self.last_matched_loss = None + self.last_unmatched_loss = None + + def _initialize_vocabulary_mapping(self): + """初始化学生-教师tokenizer的词汇映射""" + student_vocab = self.student_tokenizer.get_vocab() + teacher_vocab = self.teacher_tokenizer.get_vocab() + + student_token_to_id = dict(student_vocab.items()) + + vocab_mapping = {} + teacher_matched_ids = set() + student_matched_ids = set() + + for token_str, teacher_id in teacher_vocab.items(): + if token_str in student_token_to_id: + student_id = student_token_to_id[token_str] + vocab_mapping[teacher_id] = student_id + teacher_matched_ids.add(teacher_id) + student_matched_ids.add(student_id) + + self._vocab_mapping = vocab_mapping + self._teacher_matched_ids = teacher_matched_ids + self._student_matched_ids = student_matched_ids + + if self._vocab_mapping: + max_matched_teacher_id = max(self._vocab_mapping.keys()) + self.mapping_tensor = torch.full( + (max_matched_teacher_id + 1,), -1, dtype=torch.long + ) + for k, v in self._vocab_mapping.items(): + self.mapping_tensor[k] = v + if self.device: + self.mapping_tensor = self.mapping_tensor.to(self.device) + + def forward( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + student_labels: torch.Tensor, + teacher_labels: torch.Tensor, + student_input_ids: torch.Tensor, + teacher_input_ids: torch.Tensor, + ) -> torch.Tensor: + """ + 计算GOLD/ULD损失 + + Args: + student_logits: [batch_size, seq_len, student_vocab_size] + teacher_logits: [batch_size, seq_len, teacher_vocab_size] + student_labels: [batch_size, seq_len], -100表示忽略 + teacher_labels: [batch_size, seq_len], -100表示忽略 + student_input_ids: [batch_size, seq_len] + teacher_input_ids: [batch_size, seq_len] + + Returns: + loss: scalar tensor + """ + + if not self.use_uld_loss: + return torch.tensor(0.0, device=student_logits.device, requires_grad=True) + + # 1. Cross-entropy loss (可选) + crossentropy_loss = self._compute_cross_entropy(student_logits, student_labels) + + # 2. Distillation loss (ULD) + distillation_loss = self._compute_distillation_loss( + student_logits, teacher_logits, + student_labels, teacher_labels, + student_input_ids, teacher_input_ids + ) + return crossentropy_loss + distillation_loss + + def _compute_cross_entropy( + self, + student_logits: torch.Tensor, + student_labels: torch.Tensor + ) -> torch.Tensor: + """计算cross-entropy loss""" + if self.crossentropy_weight <= 0: + return torch.tensor(0.0, device=student_logits.device, requires_grad=True) + + shift_logits = student_logits[..., :-1, :].contiguous() + shift_labels = student_labels[..., 1:].contiguous() + + loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + ce_loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1) + ) + return self.crossentropy_weight * ce_loss + + def _compute_distillation_loss( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + student_labels: torch.Tensor, + teacher_labels: torch.Tensor, + student_input_ids: torch.Tensor, + teacher_input_ids: torch.Tensor, + ) -> torch.Tensor: + """计算ULD蒸馏损失""" + # 获取答案区域 + student_answer_idx, student_answer_size = self._get_answer_regions(student_labels) + teacher_answer_idx, teacher_answer_size = self._get_answer_regions(teacher_labels) + + if self.skip_student_eos: + student_answer_size = [s - 1 for s in student_answer_size] + if self.skip_teacher_eos: + teacher_answer_size = [t - 1 for t in teacher_answer_size] + + # 边界检查 + if not student_answer_size or not teacher_answer_size: + return torch.zeros(1, device=student_logits.device, requires_grad=True) * 1e-8 + + batch_size = student_logits.size(0) + distillation_losses = [] + + for i in range(batch_size): + s_start = student_answer_idx[i] + s_size = student_answer_size[i] + t_start = teacher_answer_idx[i] + t_size = teacher_answer_size[i] + + if s_size <= 0 or t_size <= 0: + loss_i = student_logits[i].sum() * 0.0 + # Ensure the loss tensor requires gradients + loss_i = loss_i.detach().requires_grad_(True) + distillation_losses.append(loss_i) + continue + + # 提取答案logits + student_ans_logits = student_logits[i, s_start:s_start + s_size] + teacher_ans_logits = teacher_logits[i, t_start:t_start + t_size] + + # 转换为概率 + student_probs = F.softmax(student_ans_logits / self.student_temperature, dim=-1) + teacher_probs = F.softmax(teacher_ans_logits / self.teacher_temperature, dim=-1) + + def decode_tokens(tokenizer, token_ids): + pieces = [] + prev = "" + for k in range(len(token_ids)): + cur = tokenizer.decode(token_ids[:k + 1], skip_special_tokens=False) + pieces.append(cur[len(prev):]) + prev = cur + return pieces + student_token_ids = student_input_ids[i, s_start:s_start + s_size].tolist() + teacher_token_ids = teacher_input_ids[i, t_start:t_start + t_size].tolist() + + # Token对齐 + if self.use_extended_uld: + student_groups, teacher_groups = self._build_alignment_groups_from_ids( + student_token_ids, teacher_token_ids + ) + + student_aligned = self._merge_probabilities_with_groups( + student_probs, student_groups, student_token_ids + ) + teacher_aligned = self._merge_probabilities_with_groups( + teacher_probs, teacher_groups, teacher_token_ids + ) + + else: + min_len = min(len(student_token_ids), len(teacher_token_ids)) + student_aligned = student_probs[:min_len] + teacher_aligned = teacher_probs[:min_len] + + # 计算损失 + if self.use_hybrid_loss and self._vocab_mapping: + aligned_loss = self._compute_hybrid_uld_loss(student_aligned, teacher_aligned) + else: + aligned_loss = self._compute_basic_uld_loss(student_aligned, teacher_aligned) + + distillation_losses.append(aligned_loss) + distillation_loss = torch.stack(distillation_losses).mean() + return self.distillation_weight * distillation_loss + + def _get_answer_regions(self, labels: torch.Tensor) -> Tuple[List[int], List[int]]: + """获取答案区域的起始位置和大小""" + indices = [] + sizes = [] + + for label in labels: + mask = label.ne(self.ignore_index) + if not mask.any(): + indices.append(0) + sizes.append(0) + continue + + valid_indices = mask.nonzero(as_tuple=True)[0] + indices.append(int(valid_indices[0].item())) + sizes.append(int(mask.sum().item())) + + return indices, sizes + + def _build_alignment_groups_from_ids( + self, + student_token_ids: List[int], + teacher_token_ids: List[int] + ) -> Tuple[List[List[int]], List[List[int]]]: + """ + 基于文本内容构建对齐组 + 使用贪心子串匹配算法 + """ + + def decode_tokens(tokenizer, token_ids): + pieces = [] + prev = "" + for k in range(len(token_ids)): + cur = tokenizer.decode(token_ids[:k + 1], skip_special_tokens=False) + pieces.append(cur[len(prev):]) + prev = cur + return pieces + + student_pieces = decode_tokens(self.student_tokenizer, student_token_ids) + teacher_pieces = decode_tokens(self.teacher_tokenizer, teacher_token_ids) + + # 贪心匹配算法 + student_groups = [] + teacher_groups = [] + s_idx = 0 + t_idx = 0 + + while s_idx < len(student_pieces) and t_idx < len(teacher_pieces): + student_text = "" + teacher_text = "" + student_group = [] + teacher_group = [] + + # 尝试找到最短的连续匹配序列 + while s_idx < len(student_pieces) and t_idx < len(teacher_pieces): + if not student_group: + student_group.append(s_idx) + student_text += student_pieces[s_idx] + s_idx += 1 + + if not teacher_group: + teacher_group.append(t_idx) + teacher_text += teacher_pieces[t_idx] + t_idx += 1 + + # 检查是否匹配 + if student_text == teacher_text: + student_groups.append(student_group) + teacher_groups.append(teacher_group) + break + elif len(student_text) < len(teacher_text): + if s_idx < len(student_pieces): + student_group.append(s_idx) + student_text += student_pieces[s_idx] + s_idx += 1 + else: + break + else: + if t_idx < len(teacher_pieces): + teacher_group.append(t_idx) + teacher_text += teacher_pieces[t_idx] + t_idx += 1 + else: + break + else: + # 未完全匹配,添加剩余部分 + if student_group and teacher_group: + student_groups.append(student_group) + teacher_groups.append(teacher_group) + + return student_groups, teacher_groups + + def _merge_probabilities_with_groups( + self, + probs: torch.Tensor, + alignment_groups: List[List[int]], + token_ids: List[int], + ) -> torch.Tensor: + """ + 根据对齐组合并概率分布 + 使用链式法则: P_merged = P(y|x_0) * P(x_1|x_0) * P(x_2|x_0,x_1) * ... + """ + aligned_probs = [] + + for group in alignment_groups: + if len(group) > 1: + # 第一个token的边际概率 + marginal_probs = probs[group[0]] # [vocab_size] + + # 后续token的条件概率(标量) + conditional_product = 1.0 + for k in range(1, len(group)): + cond_prob = probs[group[k], token_ids[group[k - 1]]] + conditional_product *= cond_prob + + merged_probs = marginal_probs * conditional_product + aligned_probs.append(merged_probs) + elif len(group) == 1: + aligned_probs.append(probs[group[0]]) + + if aligned_probs: + return torch.stack(aligned_probs) + else: + # 返回一个空的但需要梯度的张量 + empty_tensor = probs[:0].detach().requires_grad_(True) + return empty_tensor + + def _compute_basic_uld_loss( + self, + student_aligned: torch.Tensor, + teacher_aligned: torch.Tensor, + ) -> torch.Tensor: + """基础ULD损失:排序后的L1距离""" + student_sorted = student_aligned.sort(dim=-1, descending=True).values + teacher_sorted = teacher_aligned.sort(dim=-1, descending=True).values + + # Padding到相同vocab size + s_vocab = student_sorted.size(-1) + t_vocab = teacher_sorted.size(-1) + max_vocab = max(s_vocab, t_vocab) + + if s_vocab < max_vocab: + student_sorted = F.pad(student_sorted, (0, max_vocab - s_vocab)) + if t_vocab < max_vocab: + teacher_sorted = F.pad(teacher_sorted, (0, max_vocab - t_vocab)) + + loss = F.l1_loss(student_sorted, teacher_sorted, reduction="sum") + loss /= student_aligned.size(0) + + return loss + + def _compute_hybrid_uld_loss( + self, + student_aligned: torch.Tensor, + teacher_aligned: torch.Tensor, + ) -> torch.Tensor: + """混合ULD损失:matched用JSD,unmatched用排序L1""" + device = student_aligned.device + s_vocab = student_aligned.size(-1) + t_vocab = teacher_aligned.size(-1) + + # 创建matched/unmatched masks + if self._teacher_matched_ids: + teacher_matched_idx = torch.tensor( + sorted(self._teacher_matched_ids), dtype=torch.long, device=device + ) + student_matched_idx = self.mapping_tensor[teacher_matched_idx] + else: + teacher_matched_idx = torch.tensor([], dtype=torch.long, device=device) + student_matched_idx = torch.tensor([], dtype=torch.long, device=device) + + teacher_matched_mask = torch.zeros(t_vocab, dtype=torch.bool, device=device) + student_matched_mask = torch.zeros(s_vocab, dtype=torch.bool, device=device) + + if len(teacher_matched_idx) > 0: + teacher_matched_mask[teacher_matched_idx] = True + student_matched_mask[student_matched_idx] = True + + # 1. Matched tokens的JSD损失 + matched_loss = torch.tensor(0.0, device=device, requires_grad=True) + matched_count = 0 + + if len(teacher_matched_idx) > 0: + teacher_matched_probs = teacher_aligned[:, teacher_matched_idx] + student_matched_probs = student_aligned[:, student_matched_idx] + matched_count = teacher_matched_probs.size(-1) + + matched_loss = self._compute_jsd_for_matched( + student_matched_probs, teacher_matched_probs + ) + # 2. Unmatched tokens的排序L1损失 + teacher_unmatched = teacher_aligned[:, ~teacher_matched_mask] + student_unmatched = student_aligned[:, ~student_matched_mask] + + unmatched_loss = torch.tensor(0.0, device=device, requires_grad=True) + if teacher_unmatched.size(-1) > 0 and student_unmatched.size(-1) > 0: + teacher_sorted = teacher_unmatched.sort(dim=-1, descending=True).values + student_sorted = student_unmatched.sort(dim=-1, descending=True).values + + t_size = teacher_sorted.size(-1) + s_size = student_sorted.size(-1) + max_size = max(t_size, s_size) + + if t_size < max_size: + teacher_sorted = F.pad(teacher_sorted, (0, max_size - t_size)) + if s_size < max_size: + student_sorted = F.pad(student_sorted, (0, max_size - s_size)) + + unmatched_loss = F.l1_loss(student_sorted, teacher_sorted, reduction="sum") + unmatched_loss /= student_aligned.size(0) + + # 3. 加权组合 + if self.hybrid_matched_weight is None: + w_matched = matched_count / max(1, t_vocab) + w_unmatched = 1.0 - w_matched + else: + w_matched = self.hybrid_matched_weight + w_unmatched = self.hybrid_unmatched_weight + + total_loss = w_matched * matched_loss + w_unmatched * unmatched_loss + + # 保存用于logging + self.last_matched_loss = matched_loss + self.last_unmatched_loss = unmatched_loss + + return total_loss + + def _compute_jsd_for_matched( + self, + student_probs: torch.Tensor, + teacher_probs: torch.Tensor, + epsilon: float = 1e-8 + ) -> torch.Tensor: + """计算matched tokens的JSD损失,添加数值稳定性处理""" + batch_seq_len, num_matched = student_probs.shape + + # 检查输入概率分布是否有效 + if torch.isnan(student_probs).any() or torch.isnan(teacher_probs).any(): + return torch.tensor(0.0, device=student_probs.device, requires_grad=True) + + # 添加epsilon防止数值下溢和log(0) + student_probs = student_probs.clamp(min=epsilon) + teacher_probs = teacher_probs.clamp(min=epsilon) + + # 重新归一化概率分布 + student_probs = student_probs / student_probs.sum(dim=-1, keepdim=True) + teacher_probs = teacher_probs / teacher_probs.sum(dim=-1, keepdim=True) + + student_flat = student_probs.view(-1, num_matched) + teacher_flat = teacher_probs.view(-1, num_matched) + + # JSD = 0.5 * KL(P||M) + 0.5 * KL(Q||M), where M = 0.5*(P+Q) + m = 0.5 * (student_flat + teacher_flat) + + # 添加epsilon到中间分布 + m = m.clamp(min=epsilon) + m = m / m.sum(dim=-1, keepdim=True) + + # 直接对概率分布取对数,添加epsilon防止数值问题 + log_m = torch.log(m + epsilon) + log_student = torch.log(student_flat + epsilon) + log_teacher = torch.log(teacher_flat + epsilon) + + # 使用log_target=True,传入log概率 + kl_p_m = F.kl_div(log_m, log_student, reduction='batchmean', log_target=True) + kl_q_m = F.kl_div(log_m, log_teacher, reduction='batchmean', log_target=True) + jsd = 0.5 * (kl_p_m + kl_q_m) + + # 检查结果是否有效 + if torch.isnan(jsd) or torch.isinf(jsd): + return torch.tensor(0.0, device=student_probs.device, requires_grad=True) + + return jsd \ No newline at end of file diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index d077cbccc0..c4990611dd 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -16,6 +16,9 @@ from transformers import PreTrainedModel from trl import SFTTrainer as HFSFTTrainer from typing import Dict, Optional, Union +from transformers import AutoTokenizer +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from trl.trainer.utils import pad from swift.template import TemplateInputs from swift.trainers import SwiftMixin, disable_gradient_checkpointing @@ -79,6 +82,7 @@ class GKDTrainer(RolloutTrainerMixin, SwiftMixin, HFGKDTrainer): def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs): teacher_model = kwargs.pop('teacher_model', None) + self.teacher_model_group = kwargs.pop('teacher_model_group', None) teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None) self.vllm_client = kwargs.pop('vllm_client', None) self.gkd_logits_topk = kwargs.pop('gkd_logits_topk', None) @@ -128,8 +132,29 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self.teacher_model.eval() if self.args.offload_teacher_model: self.offload_model(self.accelerator.unwrap_model(self.teacher_model)) + + # Initialize teacher model group (for MOPD) + if self.teacher_model_group is not None and len(self.teacher_model_group) > 0: + prepared_models = [] + for model_name in self.teacher_model_group: + if self.is_deepspeed_enabled: + if teacher_deepspeed_config is not None: + prepared_model = prepare_deepspeed( + model_name, self.accelerator, deepspeed_config=teacher_deepspeed_config, training_args=args) + else: + prepared_model = prepare_deepspeed(model_name, self.accelerator) + elif self.is_fsdp_enabled: + from .utils import prepare_fsdp + prepared_model = prepare_fsdp(model_name, self.accelerator) + else: + prepared_model = self.accelerator.prepare_model(model_name, evaluation_mode=True) + prepared_model.eval() + if self.args.offload_teacher_model: + self.offload_model(self.accelerator.unwrap_model(prepared_model)) + prepared_models.append(prepared_model) + self.teacher_model_group = prepared_models else: - self.teacher_model = None + self.teacher_model_group = None # Initialize rollout infrastructure for vLLM support self.prepare_rollout() @@ -274,7 +299,32 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token new_position_ids = new_attention_mask.cumsum(dim=1) - 1 new_position_ids[new_position_ids < 0] = 0 inputs['position_ids'] = new_position_ids - return generated_tokens, new_attention_mask, new_labels + # 返回解码后的文本encoded_inputs + batch_size = generated_tokens.shape[0] + prompt_length = prompt_input_ids.shape[1] + + completion_texts = [] + pad_token_id = self.processing_class.pad_token_id + eos_token_id = self.processing_class.eos_token_id + for i in range(batch_size): + # Decode completion + completion_ids = generated_tokens[i][prompt_length:].tolist() + # 截断到第一个 EOS 或 PAD token + cleaned_ids = [] + for token_id in completion_ids: + if token_id == eos_token_id or token_id == pad_token_id: + break + cleaned_ids.append(token_id) + + # 解码清理后的 token IDs + completion_text = self.template.safe_decode(cleaned_ids) + + # 进一步清理空白字符(可选) + completion_text = completion_text.strip() + + completion_texts.append(completion_text) + + return generated_tokens, new_attention_mask, new_labels, completion_texts @profiling_decorator def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): @@ -298,156 +348,380 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N self.prepare_logits_to_keep(inputs) model_inputs['logits_to_keep'] = inputs['logits_to_keep'] - teacher_model = self.choose_teacher_model() - if self.use_liger_gkd_loss: - # Liger fused JSD loss for memory efficiency - # Get base models (exclude lm_head to save memory) - unwrapped_student = self.accelerator.unwrap_model(model) - if is_peft_model(unwrapped_student): - unwrapped_student = unwrapped_student.base_model.model - base_student = getattr(unwrapped_student, getattr(unwrapped_student, 'base_model_prefix', 'model'), - unwrapped_student) - - unwrapped_teacher = self.accelerator.unwrap_model(teacher_model) - base_teacher = getattr(unwrapped_teacher, getattr(unwrapped_teacher, 'base_model_prefix', 'model'), - unwrapped_teacher) - - # Forward through base models - student_outputs = base_student(**model_inputs, use_cache=False) - - load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() - with load_context: - with torch.no_grad(), disable_gradient_checkpointing(teacher_model, - self.args.gradient_checkpointing_kwargs): - teacher_outputs = base_teacher(**model_inputs, use_cache=False) - - # Get hidden states (shifted) - student_hidden = student_outputs.last_hidden_state[:, :-1] - teacher_hidden = teacher_outputs.last_hidden_state[:, :-1] - - # Release full outputs to free memory - del student_outputs, teacher_outputs - - # Prepare labels (shifted) - labels_mask = inputs['labels'] != -100 - masked_input_ids = torch.where(labels_mask, inputs['input_ids'], - torch.full_like(inputs['input_ids'], -100)) - true_labels = masked_input_ids[:, 1:].contiguous() - - # Release intermediate tensors - del labels_mask, masked_input_ids - - # Get output heads - student_head = unwrapped_student.get_output_embeddings() - teacher_head = unwrapped_teacher.get_output_embeddings() - - # Prepare context managers for gathering parameters in zero3 - teacher_context = get_gather_if_zero3_context(self, is_zero3=self.is_teacher_ds3)(teacher_head.weight) - student_context = get_gather_if_zero3_context(self)(student_head.weight) - - with teacher_context, student_context: - # Compute liger fused JSD loss - loss = self.liger_jsd_loss( - student_input=student_hidden, - student_weight=student_head.weight, - teacher_input=teacher_hidden, - teacher_weight=teacher_head.weight, - true_labels=true_labels, - student_bias=getattr(student_head, 'bias', None), - teacher_bias=getattr(teacher_head, 'bias', None), - ) - # Release hidden states after loss computation - del student_hidden, teacher_hidden, true_labels - outputs_student = None - # Teacher API mode: top-k logprobs fetched from external teacher server - elif self.use_teacher_api: - assert teacher_api_logprobs is not None - if self.args.sft_alpha > 0: - model_inputs['labels'] = inputs['labels'] - outputs_student = model(**model_inputs) - - # teacher_api shape: [batch, seq_len-1, topk] - # Pad to [batch, seq_len, topk] so it aligns with student logits. - teacher_api_logprobs = F.pad(teacher_api_logprobs, (0, 0, 0, 1), value=float('-inf')) - teacher_api_indices = F.pad(teacher_api_indices, (0, 0, 0, 1), value=0) - logits_to_keep = inputs.get('logits_to_keep') - if logits_to_keep is not None: - if isinstance(logits_to_keep, torch.Tensor) and logits_to_keep.dtype == torch.bool: - teacher_api_logprobs = teacher_api_logprobs[:, logits_to_keep] - teacher_api_indices = teacher_api_indices[:, logits_to_keep] - else: - n = logits_to_keep.item() if isinstance(logits_to_keep, torch.Tensor) else int(logits_to_keep) - teacher_api_logprobs = teacher_api_logprobs[:, -n:] - teacher_api_indices = teacher_api_indices[:, -n:] - - opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None - teacher_out = TeacherOutput( - topk_logprobs=teacher_api_logprobs, - topk_indices=teacher_api_indices, - opsd_teacher_labels=opsd_labels, - ) - loss = self._compute_jsd_loss(outputs_student.logits, teacher_out, inputs['labels']) - - if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: - loss = loss + self.args.sft_alpha * outputs_student.loss - # Self-distillation mode: student model doubles as teacher - elif self._is_self_distillation: - if self.args.sft_alpha > 0: - model_inputs['labels'] = inputs['labels'] - outputs_student = model(**model_inputs) - - t_fwd = teacher_fwd_inputs if teacher_fwd_inputs is not None else { - k: v - for k, v in model_inputs.items() if k != 'labels' - } + loss_total = 0.0 + if self.teacher_model_group is None: + # Use single teacher model + teacher_model_group = [self.teacher_model] + else: + teacher_model_group = self.teacher_model_group + for teacher_model in teacher_model_group: + if self.use_liger_gkd_loss: + # Liger fused JSD loss for memory efficiency + # Get base models (exclude lm_head to save memory) + unwrapped_student = self.accelerator.unwrap_model(model) + if is_peft_model(unwrapped_student): + unwrapped_student = unwrapped_student.base_model.model + base_student = getattr(unwrapped_student, getattr(unwrapped_student, 'base_model_prefix', 'model'), + unwrapped_student) + + unwrapped_teacher = self.accelerator.unwrap_model(teacher_model) + base_teacher = getattr(unwrapped_teacher, getattr(unwrapped_teacher, 'base_model_prefix', 'model'), + unwrapped_teacher) + + # Forward through base models + student_outputs = base_student(**model_inputs, use_cache=False) - adapter_ctx = ( - self.accelerator.unwrap_model(model).disable_adapter() - if self._teacher_use_disable_adapter else nullcontext()) - with torch.no_grad(), adapter_ctx, \ - disable_gradient_checkpointing(model, self.args.gradient_checkpointing_kwargs): - outputs_teacher = model(**t_fwd) + load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() + with load_context: + with torch.no_grad(), disable_gradient_checkpointing(teacher_model, + self.args.gradient_checkpointing_kwargs): + teacher_outputs = base_teacher(**model_inputs, use_cache=False) + + # Get hidden states (shifted) + student_hidden = student_outputs.last_hidden_state[:, :-1] + teacher_hidden = teacher_outputs.last_hidden_state[:, :-1] + + # Release full outputs to free memory + del student_outputs, teacher_outputs + + # Prepare labels (shifted) + labels_mask = inputs['labels'] != -100 + masked_input_ids = torch.where(labels_mask, inputs['input_ids'], + torch.full_like(inputs['input_ids'], -100)) + true_labels = masked_input_ids[:, 1:].contiguous() + + # Release intermediate tensors + del labels_mask, masked_input_ids + + # Get output heads + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + + # Prepare context managers for gathering parameters in zero3 + teacher_context = get_gather_if_zero3_context(self, is_zero3=self.is_teacher_ds3)(teacher_head.weight) + student_context = get_gather_if_zero3_context(self)(student_head.weight) + + with teacher_context, student_context: + # Compute liger fused JSD loss + loss = self.liger_jsd_loss( + student_input=student_hidden, + student_weight=student_head.weight, + teacher_input=teacher_hidden, + teacher_weight=teacher_head.weight, + true_labels=true_labels, + student_bias=getattr(student_head, 'bias', None), + teacher_bias=getattr(teacher_head, 'bias', None), + ) + # Release hidden states after loss computation + del student_hidden, teacher_hidden, true_labels + outputs_student = None + # Teacher API mode: top-k logprobs fetched from external teacher server + elif self.use_teacher_api: + assert teacher_api_logprobs is not None + if self.args.sft_alpha > 0: + model_inputs['labels'] = inputs['labels'] + outputs_student = model(**model_inputs) + + # teacher_api shape: [batch, seq_len-1, topk] + # Pad to [batch, seq_len, topk] so it aligns with student logits. + teacher_api_logprobs = F.pad(teacher_api_logprobs, (0, 0, 0, 1), value=float('-inf')) + teacher_api_indices = F.pad(teacher_api_indices, (0, 0, 0, 1), value=0) + logits_to_keep = inputs.get('logits_to_keep') + if logits_to_keep is not None: + if isinstance(logits_to_keep, torch.Tensor) and logits_to_keep.dtype == torch.bool: + teacher_api_logprobs = teacher_api_logprobs[:, logits_to_keep] + teacher_api_indices = teacher_api_indices[:, logits_to_keep] + else: + n = logits_to_keep.item() if isinstance(logits_to_keep, torch.Tensor) else int(logits_to_keep) + teacher_api_logprobs = teacher_api_logprobs[:, -n:] + teacher_api_indices = teacher_api_indices[:, -n:] + + opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None + teacher_out = TeacherOutput( + topk_logprobs=teacher_api_logprobs, + topk_indices=teacher_api_indices, + opsd_teacher_labels=opsd_labels, + ) + loss = self._compute_jsd_loss(outputs_student.logits, teacher_out, inputs['labels']) + + if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: + loss = loss + self.args.sft_alpha * outputs_student.loss + # Self-distillation mode: student model doubles as teacher + elif self._is_self_distillation: + if self.args.sft_alpha > 0: + model_inputs['labels'] = inputs['labels'] + outputs_student = model(**model_inputs) + + t_fwd = teacher_fwd_inputs if teacher_fwd_inputs is not None else { + k: v + for k, v in model_inputs.items() if k != 'labels' + } + + adapter_ctx = ( + self.accelerator.unwrap_model(model).disable_adapter() + if self._teacher_use_disable_adapter else nullcontext()) + with torch.no_grad(), adapter_ctx, \ + disable_gradient_checkpointing(model, self.args.gradient_checkpointing_kwargs): + outputs_teacher = model(**t_fwd) + + opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None + teacher_out = TeacherOutput(full_logits=outputs_teacher.logits, opsd_teacher_labels=opsd_labels) + loss = self._compute_jsd_loss(outputs_student.logits, teacher_out, inputs['labels']) + + if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: + loss = loss + self.args.sft_alpha * outputs_student.loss + # Separate teacher model provided + else: + if not hasattr(self, 'student_tokenizer'): + student_model = model + student_model_path = getattr(student_model, 'name_or_path', None) + if student_model_path is None: + # Try to get path from config + if hasattr(student_model, 'config') and hasattr(student_model.config, '_name_or_path'): + student_model_path = student_model.config._name_or_path + + if student_model_path is None: + # If still None, try to get from model's base model + unwrapped_student = self.accelerator.unwrap_model(student_model) + if hasattr(unwrapped_student, 'base_model_prefix'): + base_model = getattr(unwrapped_student, unwrapped_student.base_model_prefix, unwrapped_student) + if hasattr(base_model, 'config') and hasattr(base_model.config, '_name_or_path'): + student_model_path = base_model.config._name_or_path + # Additional fallback: try to get from model's config name_or_path attribute + if student_model_path is None: + if hasattr(student_model, 'config') and hasattr(student_model.config, 'name_or_path'): + student_model_path = student_model.config.name_or_path + + # Additional fallback: try to get from unwrapped model's name_or_path + if student_model_path is None: + unwrapped_student = self.accelerator.unwrap_model(student_model) + student_model_path = getattr(unwrapped_student, 'name_or_path', None) + + # Additional fallback: try to get from unwrapped model's config name_or_path + if student_model_path is None: + unwrapped_student = self.accelerator.unwrap_model(student_model) + if hasattr(unwrapped_student, 'config') and hasattr(unwrapped_student.config, 'name_or_path'): + student_model_path = unwrapped_student.config.name_or_path + + if student_model_path is None: + # Provide detailed error information for debugging + model_info = f"Model type: {type(student_model)}, " + if hasattr(student_model, 'config'): + model_info += f"Config type: {type(student_model.config)}, " + model_info += f"Config attributes: {[attr for attr in dir(student_model.config) if not attr.startswith('_')]}" + else: + model_info += "No config available" + self.student_tokenizer = AutoTokenizer.from_pretrained(student_model_path) + # Initialize teacher tokenizer (only once) + if not hasattr(self, 'teacher_tokenizer'): + teacher_model_path = getattr(teacher_model, 'name_or_path', None) + if teacher_model_path is None: + # Try to get path from config + if hasattr(teacher_model, 'config') and hasattr(teacher_model.config, '_name_or_path'): + teacher_model_path = teacher_model.config._name_or_path + + if teacher_model_path is None: + # If still None, try to get from model's base model + unwrapped_teacher = self.accelerator.unwrap_model(teacher_model) + if hasattr(unwrapped_teacher, 'base_model_prefix'): + base_model = getattr(unwrapped_teacher, unwrapped_teacher.base_model_prefix, + unwrapped_teacher) + if hasattr(base_model, 'config') and hasattr(base_model.config, '_name_or_path'): + teacher_model_path = base_model.config._name_or_path + # Additional fallback: try to get from model's config name_or_path attribute + if teacher_model_path is None: + if hasattr(teacher_model, 'config') and hasattr(teacher_model.config, 'name_or_path'): + teacher_model_path = teacher_model.config.name_or_path + + # Additional fallback: try to get from unwrapped model's name_or_path + if teacher_model_path is None: + unwrapped_teacher = self.accelerator.unwrap_model(teacher_model) + teacher_model_path = getattr(unwrapped_teacher, 'name_or_path', None) + + # Additional fallback: try to get from unwrapped model's config name_or_path + if teacher_model_path is None: + unwrapped_teacher = self.accelerator.unwrap_model(teacher_model) + if hasattr(unwrapped_teacher, 'config') and hasattr(unwrapped_teacher.config, + 'name_or_path'): + teacher_model_path = unwrapped_teacher.config.name_or_path + if teacher_model_path is None: + # Provide detailed error information for debugging + model_info = f"Model type: {type(teacher_model)}, " + if hasattr(teacher_model, 'config'): + model_info += f"Config type: {type(teacher_model.config)}, " + model_info += f"Config attributes: {[attr for attr in dir(teacher_model.config) if not attr.startswith('_')]}" + else: + model_info += "No config available" + + unwrapped_info = "" + try: + unwrapped = self.accelerator.unwrap_model(teacher_model) + unwrapped_info = f"Unwrapped model type: {type(unwrapped)}, " + if hasattr(unwrapped, 'config'): + unwrapped_info += f"Unwrapped config attributes: {[attr for attr in dir(unwrapped.config) if not attr.startswith('_')]}" + except Exception as e: + unwrapped_info = f"Failed to unwrap model: {e}" + + raise ValueError(f"Cannot determine teacher model path for tokenizer initialization. " + f"Model info: {model_info}. Unwrapped info: {unwrapped_info}") + self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_path) + + from .GOLDLossAdapter import GOLDLossAdapter + # Initialize adapter only once + if not hasattr(self, 'gold_adapter'): + self.gold_adapter = GOLDLossAdapter( + config={ + 'use_uld_loss': True, + 'use_extended_uld': True, + 'uld_use_hybrid_loss': True, + 'uld_crossentropy_weight': 0.0, + 'uld_distillation_weight': 1.0, + 'uld_student_temperature': 1.0, + 'uld_teacher_temperature': 1.0, + }, + student_tokenizer=self.student_tokenizer, + teacher_tokenizer=self.teacher_tokenizer, + ) + prompt_texts = inputs['prompt_text'] + completion_texts = inputs['completion_texts'] + + # Add teacher model memory management like in liger branch + load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() + with load_context: + with torch.no_grad(), disable_gradient_checkpointing(teacher_model, + self.args.gradient_checkpointing_kwargs): + ( + teacher_input_ids, + teacher_labels, + teacher_attention_mask, + teacher_prompt_length, + ) = self.build_teacher_inputs_from_texts( + self.teacher_tokenizer, + prompt_texts, + completion_texts + ) + ( + student_input_ids, + student_labels, + student_attention_mask, + student_prompt_length, + ) = self.build_teacher_inputs_from_texts( + self.student_tokenizer, + prompt_texts, + completion_texts + ) + + # Teacher model forward pass (NO gradients) + outputs_teacher = teacher_model( + input_ids=teacher_input_ids, + attention_mask=teacher_attention_mask, + ) + # Student model forward pass (WITH gradients for student parameters) + outputs_student = model( + input_ids=student_input_ids, + attention_mask=student_attention_mask, + ) - opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None - teacher_out = TeacherOutput(full_logits=outputs_teacher.logits, opsd_teacher_labels=opsd_labels) - loss = self._compute_jsd_loss(outputs_student.logits, teacher_out, inputs['labels']) + # Ensure teacher_logits has gradient info but teacher model params don't participate + teacher_logits = outputs_teacher.logits.detach().requires_grad_(True) - if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: - loss = loss + self.args.sft_alpha * outputs_student.loss - # Separate teacher model provided - else: - assert teacher_model is not None - if self.args.sft_alpha > 0: - model_inputs['labels'] = inputs['labels'] - outputs_student = model(**model_inputs) - - t_fwd = teacher_fwd_inputs if teacher_fwd_inputs is not None else { - k: v - for k, v in model_inputs.items() if k != 'labels' - } + # Release intermediate tensors to free memory + del teacher_input_ids, teacher_attention_mask, student_input_ids, student_attention_mask - load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() - with torch.no_grad(), load_context, disable_gradient_checkpointing(teacher_model, - self.args.gradient_checkpointing_kwargs): - outputs_teacher = teacher_model(**t_fwd) + loss = self.gold_adapter( + student_logits=outputs_student.logits, + teacher_logits=teacher_logits, + student_labels=student_labels, + teacher_labels=teacher_labels, + student_input_ids=student_input_ids, + teacher_input_ids=teacher_input_ids, + ) - opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None - teacher_out = TeacherOutput(full_logits=outputs_teacher.logits, opsd_teacher_labels=opsd_labels) - loss = self._compute_jsd_loss(outputs_student.logits, teacher_out, inputs['labels']) + loss_total += loss / len(self.teacher_model_group) - if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: - loss = loss + self.args.sft_alpha * outputs_student.loss + # Return loss + if return_outputs: + if self.use_liger_gkd_loss: + # outputs has been released in liger loss computation to reduce peak memory + outputs_student = None + return (loss_total, outputs_student) + else: + return loss_total + + def build_inputs_from_texts( + self, + tokenizer: PreTrainedTokenizerBase, + prompt_texts: list[str], + completion_texts: list[str], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Tokenize teacher prompts/completions and produce tensors ready for GOLD loss.""" + + pad_token_id = tokenizer.pad_token_id + eos_token_id = tokenizer.eos_token_id + + prompt_token_ids = tokenizer(prompt_texts, add_special_tokens=True)["input_ids"] + completion_token_ids = tokenizer(completion_texts, add_special_tokens=False)["input_ids"] + + sequences: list[torch.Tensor] = [] + attention_masks: list[torch.Tensor] = [] + labels_list: list[torch.Tensor] = [] + prompt_lengths: list[int] = [] + # Get device using reliable detection method + device = None + try: + # First try to get device from model parameters + if hasattr(self, 'model') and self.model is not None: + device = next(self.model.parameters()).device + elif hasattr(self, 'teacher_model') and self.teacher_model is not None: + device = next(self.teacher_model.parameters()).device + except (AttributeError, StopIteration): + pass + + # Fallback to default device detection + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + for prompt_ids, completion_ids in zip(prompt_token_ids, completion_token_ids, strict=True): + # Remove trailing EOS from prompt so completions can extend cleanly + if eos_token_id is not None and prompt_ids and prompt_ids[-1] == eos_token_id: + prompt_ids = prompt_ids[:-1] + + prompt_lengths.append(len(prompt_ids)) + sequence = list(prompt_ids) + sequence.extend(completion_ids) + if eos_token_id is not None: + sequence.append(eos_token_id) + + seq_tensor = torch.tensor(sequence, dtype=torch.long, device=device) + sequences.append(seq_tensor) + attention_masks.append(torch.ones_like(seq_tensor)) + labels = seq_tensor.clone() + labels[: len(prompt_ids)] = -100 + if pad_token_id is not None: + labels[labels == pad_token_id] = -100 + labels_list.append(labels) + + teacher_input_ids = pad( + sequences, + padding_side="right", + padding_value=pad_token_id if pad_token_id is not None else 0, + ) + teacher_attention_mask = pad(attention_masks, padding_side="right", padding_value=0).bool() + teacher_labels = pad(labels_list, padding_side="right", padding_value=-100) + + if eos_token_id is not None: + for row in range(teacher_attention_mask.size(0)): + valid = ( + teacher_input_ids[row] != pad_token_id + if pad_token_id is not None + else teacher_attention_mask[row].bool() + ) + if valid.any(): + last_idx = valid.nonzero(as_tuple=True)[0][-1] + teacher_attention_mask[row, last_idx + 1:] = False - # Return loss - if return_outputs: - return (loss, outputs_student) - else: - return loss + teacher_prompt_length = max(prompt_lengths) if prompt_lengths else 0 - def choose_teacher_model(self): - if not self.args.use_mopd: - return self.teacher_model - #todo 使用mopd时从教师模型组选择最佳模型 + return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length def _prepare_batch_inputs(self, inputs: list, encode_prompt_only: bool = False) -> Dict[str, torch.Tensor]: """Prepare batch inputs for training. From dc4be10d20d2db079ea60c8915be821ffc497804 Mon Sep 17 00:00:00 2001 From: Yao Chuming <1416004356@qq.com> Date: Tue, 12 May 2026 10:52:56 +0800 Subject: [PATCH 03/13] [MOPD] rename gold_loss_adapter.py --- swift/rlhf_trainers/gkd_trainer.py | 2 +- .../rlhf_trainers/{GOLDLossAdapter.py => gold_loss_adapter.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename swift/rlhf_trainers/{GOLDLossAdapter.py => gold_loss_adapter.py} (100%) diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index c4990611dd..4746a3bcb5 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -566,7 +566,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N f"Model info: {model_info}. Unwrapped info: {unwrapped_info}") self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_path) - from .GOLDLossAdapter import GOLDLossAdapter + from .gold_loss_adapter import GOLDLossAdapter # Initialize adapter only once if not hasattr(self, 'gold_adapter'): self.gold_adapter = GOLDLossAdapter( diff --git a/swift/rlhf_trainers/GOLDLossAdapter.py b/swift/rlhf_trainers/gold_loss_adapter.py similarity index 100% rename from swift/rlhf_trainers/GOLDLossAdapter.py rename to swift/rlhf_trainers/gold_loss_adapter.py From a1fcda7250adec02c2d71c81ddf41ffeb62eee83 Mon Sep 17 00:00:00 2001 From: Chuming Yao <1416004356@qq.com> Date: Tue, 12 May 2026 15:32:18 +0800 Subject: [PATCH 04/13] [MOPD] extract method --- swift/rlhf_trainers/gkd_trainer.py | 122 ++++++----------------- swift/rlhf_trainers/gold_loss_adapter.py | 2 +- 2 files changed, 33 insertions(+), 91 deletions(-) diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 4746a3bcb5..f8137aa8c1 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -319,7 +319,6 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token # 解码清理后的 token IDs completion_text = self.template.safe_decode(cleaned_ids) - # 进一步清理空白字符(可选) completion_text = completion_text.strip() completion_texts.append(completion_text) @@ -474,96 +473,12 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N else: if not hasattr(self, 'student_tokenizer'): student_model = model - student_model_path = getattr(student_model, 'name_or_path', None) - if student_model_path is None: - # Try to get path from config - if hasattr(student_model, 'config') and hasattr(student_model.config, '_name_or_path'): - student_model_path = student_model.config._name_or_path - - if student_model_path is None: - # If still None, try to get from model's base model - unwrapped_student = self.accelerator.unwrap_model(student_model) - if hasattr(unwrapped_student, 'base_model_prefix'): - base_model = getattr(unwrapped_student, unwrapped_student.base_model_prefix, unwrapped_student) - if hasattr(base_model, 'config') and hasattr(base_model.config, '_name_or_path'): - student_model_path = base_model.config._name_or_path - # Additional fallback: try to get from model's config name_or_path attribute - if student_model_path is None: - if hasattr(student_model, 'config') and hasattr(student_model.config, 'name_or_path'): - student_model_path = student_model.config.name_or_path - - # Additional fallback: try to get from unwrapped model's name_or_path - if student_model_path is None: - unwrapped_student = self.accelerator.unwrap_model(student_model) - student_model_path = getattr(unwrapped_student, 'name_or_path', None) - - # Additional fallback: try to get from unwrapped model's config name_or_path - if student_model_path is None: - unwrapped_student = self.accelerator.unwrap_model(student_model) - if hasattr(unwrapped_student, 'config') and hasattr(unwrapped_student.config, 'name_or_path'): - student_model_path = unwrapped_student.config.name_or_path - - if student_model_path is None: - # Provide detailed error information for debugging - model_info = f"Model type: {type(student_model)}, " - if hasattr(student_model, 'config'): - model_info += f"Config type: {type(student_model.config)}, " - model_info += f"Config attributes: {[attr for attr in dir(student_model.config) if not attr.startswith('_')]}" - else: - model_info += "No config available" + student_model_path = self.get_model_path(student_model) self.student_tokenizer = AutoTokenizer.from_pretrained(student_model_path) # Initialize teacher tokenizer (only once) if not hasattr(self, 'teacher_tokenizer'): - teacher_model_path = getattr(teacher_model, 'name_or_path', None) - if teacher_model_path is None: - # Try to get path from config - if hasattr(teacher_model, 'config') and hasattr(teacher_model.config, '_name_or_path'): - teacher_model_path = teacher_model.config._name_or_path - - if teacher_model_path is None: - # If still None, try to get from model's base model - unwrapped_teacher = self.accelerator.unwrap_model(teacher_model) - if hasattr(unwrapped_teacher, 'base_model_prefix'): - base_model = getattr(unwrapped_teacher, unwrapped_teacher.base_model_prefix, - unwrapped_teacher) - if hasattr(base_model, 'config') and hasattr(base_model.config, '_name_or_path'): - teacher_model_path = base_model.config._name_or_path - # Additional fallback: try to get from model's config name_or_path attribute - if teacher_model_path is None: - if hasattr(teacher_model, 'config') and hasattr(teacher_model.config, 'name_or_path'): - teacher_model_path = teacher_model.config.name_or_path - - # Additional fallback: try to get from unwrapped model's name_or_path - if teacher_model_path is None: - unwrapped_teacher = self.accelerator.unwrap_model(teacher_model) - teacher_model_path = getattr(unwrapped_teacher, 'name_or_path', None) - - # Additional fallback: try to get from unwrapped model's config name_or_path - if teacher_model_path is None: - unwrapped_teacher = self.accelerator.unwrap_model(teacher_model) - if hasattr(unwrapped_teacher, 'config') and hasattr(unwrapped_teacher.config, - 'name_or_path'): - teacher_model_path = unwrapped_teacher.config.name_or_path - if teacher_model_path is None: - # Provide detailed error information for debugging - model_info = f"Model type: {type(teacher_model)}, " - if hasattr(teacher_model, 'config'): - model_info += f"Config type: {type(teacher_model.config)}, " - model_info += f"Config attributes: {[attr for attr in dir(teacher_model.config) if not attr.startswith('_')]}" - else: - model_info += "No config available" - - unwrapped_info = "" - try: - unwrapped = self.accelerator.unwrap_model(teacher_model) - unwrapped_info = f"Unwrapped model type: {type(unwrapped)}, " - if hasattr(unwrapped, 'config'): - unwrapped_info += f"Unwrapped config attributes: {[attr for attr in dir(unwrapped.config) if not attr.startswith('_')]}" - except Exception as e: - unwrapped_info = f"Failed to unwrap model: {e}" - - raise ValueError(f"Cannot determine teacher model path for tokenizer initialization. " - f"Model info: {model_info}. Unwrapped info: {unwrapped_info}") + teacher_model_path = self.get_model_path(teacher_model) + self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_path) from .gold_loss_adapter import GOLDLossAdapter @@ -626,7 +541,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_logits = outputs_teacher.logits.detach().requires_grad_(True) # Release intermediate tensors to free memory - del teacher_input_ids, teacher_attention_mask, student_input_ids, student_attention_mask + del teacher_attention_mask, student_attention_mask loss = self.gold_adapter( student_logits=outputs_student.logits, @@ -723,6 +638,34 @@ def build_inputs_from_texts( return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length + def get_model_path(self, model): + model_path = getattr(model, 'name_or_path', None) + if model_path is None: + # Try to get path from config + if hasattr(model, 'config') and hasattr(model.config, '_name_or_path'): + model_path = model.config._name_or_path + if model_path is None: + # If still None, try to get from model's base model + unwrapped_student = self.accelerator.unwrap_model(model) + if hasattr(unwrapped_student, 'base_model_prefix'): + base_model = getattr(unwrapped_student, unwrapped_student.base_model_prefix, unwrapped_student) + if hasattr(base_model, 'config') and hasattr(base_model.config, '_name_or_path'): + model_path = base_model.config._name_or_path + # Additional fallback: try to get from model's config name_or_path attribute + if model_path is None: + if hasattr(model, 'config') and hasattr(model.config, 'name_or_path'): + model_path = model.config.name_or_path + # Additional fallback: try to get from unwrapped model's name_or_path + if model_path is None: + unwrapped_student = self.accelerator.unwrap_model(model) + model_path = getattr(unwrapped_student, 'name_or_path', None) + # Additional fallback: try to get from unwrapped model's config name_or_path + if model_path is None: + unwrapped_student = self.accelerator.unwrap_model(model) + if hasattr(unwrapped_student, 'config') and hasattr(unwrapped_student.config, 'name_or_path'): + model_path = unwrapped_student.config.name_or_path + return model_path + def _prepare_batch_inputs(self, inputs: list, encode_prompt_only: bool = False) -> Dict[str, torch.Tensor]: """Prepare batch inputs for training. @@ -1068,7 +1011,6 @@ def generalized_jsd_loss( t_log_probs = F.log_softmax(t_chunk, dim=-1) del s_chunk, t_chunk - #todo 使用mopd的计算函数,增加教师模型权重 if beta == 0: jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True) elif beta == 1: diff --git a/swift/rlhf_trainers/gold_loss_adapter.py b/swift/rlhf_trainers/gold_loss_adapter.py index 37ec431b8f..c4398d0abe 100644 --- a/swift/rlhf_trainers/gold_loss_adapter.py +++ b/swift/rlhf_trainers/gold_loss_adapter.py @@ -141,7 +141,7 @@ def forward( if not self.use_uld_loss: return torch.tensor(0.0, device=student_logits.device, requires_grad=True) - # 1. Cross-entropy loss (可选) + # 1. Cross-entropy loss (可选,通过crossentropy_weight设置权重) crossentropy_loss = self._compute_cross_entropy(student_logits, student_labels) # 2. Distillation loss (ULD) From 8923a548bfac53376f1c32897d2b0213a0d568a0 Mon Sep 17 00:00:00 2001 From: Chuming Yao <1416004356@qq.com> Date: Tue, 12 May 2026 19:38:32 +0800 Subject: [PATCH 05/13] =?UTF-8?q?[MOPD]=20mopd=E7=9B=B8=E5=85=B3=E8=B6=85?= =?UTF-8?q?=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- swift/arguments/rlhf_args.py | 2 + swift/pipelines/train/rlhf.py | 4 + swift/rlhf_trainers/gkd_trainer.py | 488 +++++++++++++++-------------- 3 files changed, 265 insertions(+), 229 deletions(-) diff --git a/swift/arguments/rlhf_args.py b/swift/arguments/rlhf_args.py index 5b08153d21..476f3392dd 100644 --- a/swift/arguments/rlhf_args.py +++ b/swift/arguments/rlhf_args.py @@ -66,6 +66,8 @@ class TeacherModelArguments: """ teacher_model: Optional[str] = None teacher_model_group: List[str] = field(default_factory=list) + #todo 还需要增加mopd_config + use_mopd: bool = False teacher_adapters: List[str] = field(default_factory=list) teacher_model_type: Optional[str] = field( default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'}) diff --git a/swift/pipelines/train/rlhf.py b/swift/pipelines/train/rlhf.py index 6935e34f31..a5f2626a1f 100644 --- a/swift/pipelines/train/rlhf.py +++ b/swift/pipelines/train/rlhf.py @@ -300,6 +300,10 @@ def _get_trainer_kwargs(self): else: trainer_kwargs['teacher_model_group'] = self.args.teacher_model_group trainer_kwargs['teacher_use_disable_adapter'] = getattr(self.args, '_teacher_use_disable_adapter', False) + if self.args.use_mopd: + trainer_kwargs['use_mopd'] = self.args.use_mopd + #todo + # trainer_kwargs['mopd_config'] = self.args.mopd_config return trainer_kwargs diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index f8137aa8c1..56bfe1adf5 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -84,6 +84,7 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non teacher_model = kwargs.pop('teacher_model', None) self.teacher_model_group = kwargs.pop('teacher_model_group', None) teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None) + self.use_mopd = kwargs.pop('use_mopd', False) self.vllm_client = kwargs.pop('vllm_client', None) self.gkd_logits_topk = kwargs.pop('gkd_logits_topk', None) teacher_model_server = kwargs.pop('teacher_model_server', None) @@ -132,7 +133,8 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self.teacher_model.eval() if self.args.offload_teacher_model: self.offload_model(self.accelerator.unwrap_model(self.teacher_model)) - + else: + self.teacher_model = None # Initialize teacher model group (for MOPD) if self.teacher_model_group is not None and len(self.teacher_model_group) > 0: prepared_models = [] @@ -169,7 +171,7 @@ def _get_data_collator(self, args, template): def _build_opsd_teacher_data(self, inputs): """Build teacher data for OPSD by replacing the last user message with teacher_prompt. - Returns None if teacher_prompt is not available in all examples. + Returns None if teacher_prompt is not av ailable in all examples. """ if not all('teacher_prompt' in data and data['teacher_prompt'] for data in inputs): return None @@ -299,30 +301,32 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token new_position_ids = new_attention_mask.cumsum(dim=1) - 1 new_position_ids[new_position_ids < 0] = 0 inputs['position_ids'] = new_position_ids - # 返回解码后的文本encoded_inputs - batch_size = generated_tokens.shape[0] - prompt_length = prompt_input_ids.shape[1] - - completion_texts = [] - pad_token_id = self.processing_class.pad_token_id - eos_token_id = self.processing_class.eos_token_id - for i in range(batch_size): - # Decode completion - completion_ids = generated_tokens[i][prompt_length:].tolist() - # 截断到第一个 EOS 或 PAD token - cleaned_ids = [] - for token_id in completion_ids: - if token_id == eos_token_id or token_id == pad_token_id: - break - cleaned_ids.append(token_id) - - # 解码清理后的 token IDs - completion_text = self.template.safe_decode(cleaned_ids) - - completion_text = completion_text.strip() - - completion_texts.append(completion_text) - + if self.use_mopd: + # 返回解码后的文本encoded_inputs + batch_size = generated_tokens.shape[0] + prompt_length = prompt_input_ids.shape[1] + + completion_texts = [] + pad_token_id = self.processing_class.pad_token_id + eos_token_id = self.processing_class.eos_token_id + for i in range(batch_size): + # Decode completion + completion_ids = generated_tokens[i][prompt_length:].tolist() + # 截断到第一个 EOS 或 PAD token + cleaned_ids = [] + for token_id in completion_ids: + if token_id == eos_token_id or token_id == pad_token_id: + break + cleaned_ids.append(token_id) + + # 解码清理后的 token IDs + completion_text = self.template.safe_decode(cleaned_ids) + + completion_text = completion_text.strip() + + completion_texts.append(completion_text) + else: + completion_texts = None return generated_tokens, new_attention_mask, new_labels, completion_texts @profiling_decorator @@ -347,130 +351,134 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N self.prepare_logits_to_keep(inputs) model_inputs['logits_to_keep'] = inputs['logits_to_keep'] - loss_total = 0.0 - if self.teacher_model_group is None: - # Use single teacher model - teacher_model_group = [self.teacher_model] - else: - teacher_model_group = self.teacher_model_group - for teacher_model in teacher_model_group: - if self.use_liger_gkd_loss: - # Liger fused JSD loss for memory efficiency - # Get base models (exclude lm_head to save memory) - unwrapped_student = self.accelerator.unwrap_model(model) - if is_peft_model(unwrapped_student): - unwrapped_student = unwrapped_student.base_model.model - base_student = getattr(unwrapped_student, getattr(unwrapped_student, 'base_model_prefix', 'model'), - unwrapped_student) - - unwrapped_teacher = self.accelerator.unwrap_model(teacher_model) - base_teacher = getattr(unwrapped_teacher, getattr(unwrapped_teacher, 'base_model_prefix', 'model'), - unwrapped_teacher) - - # Forward through base models - student_outputs = base_student(**model_inputs, use_cache=False) + if self.use_liger_gkd_loss: + # Liger fused JSD loss for memory efficiency + # Get base models (exclude lm_head to save memory) + unwrapped_student = self.accelerator.unwrap_model(model) + if is_peft_model(unwrapped_student): + unwrapped_student = unwrapped_student.base_model.model + base_student = getattr(unwrapped_student, getattr(unwrapped_student, 'base_model_prefix', 'model'), + unwrapped_student) + + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + base_teacher = getattr(unwrapped_teacher, getattr(unwrapped_teacher, 'base_model_prefix', 'model'), + unwrapped_teacher) + + # Forward through base models + student_outputs = base_student(**model_inputs, use_cache=False) + + load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() + with load_context: + with torch.no_grad(), disable_gradient_checkpointing(self.teacher_model, + self.args.gradient_checkpointing_kwargs): + teacher_outputs = base_teacher(**model_inputs, use_cache=False) + + # Get hidden states (shifted) + student_hidden = student_outputs.last_hidden_state[:, :-1] + teacher_hidden = teacher_outputs.last_hidden_state[:, :-1] + + # Release full outputs to free memory + del student_outputs, teacher_outputs + + # Prepare labels (shifted) + labels_mask = inputs['labels'] != -100 + masked_input_ids = torch.where(labels_mask, inputs['input_ids'], + torch.full_like(inputs['input_ids'], -100)) + true_labels = masked_input_ids[:, 1:].contiguous() + + # Release intermediate tensors + del labels_mask, masked_input_ids + + # Get output heads + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + + # Prepare context managers for gathering parameters in zero3 + teacher_context = get_gather_if_zero3_context(self, is_zero3=self.is_teacher_ds3)(teacher_head.weight) + student_context = get_gather_if_zero3_context(self)(student_head.weight) + + with teacher_context, student_context: + # Compute liger fused JSD loss + loss = self.liger_jsd_loss( + student_input=student_hidden, + student_weight=student_head.weight, + teacher_input=teacher_hidden, + teacher_weight=teacher_head.weight, + true_labels=true_labels, + student_bias=getattr(student_head, 'bias', None), + teacher_bias=getattr(teacher_head, 'bias', None), + ) + # Release hidden states after loss computation + del student_hidden, teacher_hidden, true_labels + outputs_student = None + # Teacher API mode: top-k logprobs fetched from external teacher server + elif self.use_teacher_api: + assert teacher_api_logprobs is not None + if self.args.sft_alpha > 0: + model_inputs['labels'] = inputs['labels'] + outputs_student = model(**model_inputs) + + # teacher_api shape: [batch, seq_len-1, topk] + # Pad to [batch, seq_len, topk] so it aligns with student logits. + teacher_api_logprobs = F.pad(teacher_api_logprobs, (0, 0, 0, 1), value=float('-inf')) + teacher_api_indices = F.pad(teacher_api_indices, (0, 0, 0, 1), value=0) + logits_to_keep = inputs.get('logits_to_keep') + if logits_to_keep is not None: + if isinstance(logits_to_keep, torch.Tensor) and logits_to_keep.dtype == torch.bool: + teacher_api_logprobs = teacher_api_logprobs[:, logits_to_keep] + teacher_api_indices = teacher_api_indices[:, logits_to_keep] + else: + n = logits_to_keep.item() if isinstance(logits_to_keep, torch.Tensor) else int(logits_to_keep) + teacher_api_logprobs = teacher_api_logprobs[:, -n:] + teacher_api_indices = teacher_api_indices[:, -n:] + + opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None + teacher_out = TeacherOutput( + topk_logprobs=teacher_api_logprobs, + topk_indices=teacher_api_indices, + opsd_teacher_labels=opsd_labels, + ) + loss = self._compute_jsd_loss(outputs_student.logits, teacher_out, inputs['labels']) + + if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: + loss = loss + self.args.sft_alpha * outputs_student.loss + # Self-distillation mode: student model doubles as teacher + elif self._is_self_distillation: + if self.args.sft_alpha > 0: + model_inputs['labels'] = inputs['labels'] + outputs_student = model(**model_inputs) + + t_fwd = teacher_fwd_inputs if teacher_fwd_inputs is not None else { + k: v + for k, v in model_inputs.items() if k != 'labels' + } - load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() - with load_context: - with torch.no_grad(), disable_gradient_checkpointing(teacher_model, - self.args.gradient_checkpointing_kwargs): - teacher_outputs = base_teacher(**model_inputs, use_cache=False) - - # Get hidden states (shifted) - student_hidden = student_outputs.last_hidden_state[:, :-1] - teacher_hidden = teacher_outputs.last_hidden_state[:, :-1] - - # Release full outputs to free memory - del student_outputs, teacher_outputs - - # Prepare labels (shifted) - labels_mask = inputs['labels'] != -100 - masked_input_ids = torch.where(labels_mask, inputs['input_ids'], - torch.full_like(inputs['input_ids'], -100)) - true_labels = masked_input_ids[:, 1:].contiguous() - - # Release intermediate tensors - del labels_mask, masked_input_ids - - # Get output heads - student_head = unwrapped_student.get_output_embeddings() - teacher_head = unwrapped_teacher.get_output_embeddings() - - # Prepare context managers for gathering parameters in zero3 - teacher_context = get_gather_if_zero3_context(self, is_zero3=self.is_teacher_ds3)(teacher_head.weight) - student_context = get_gather_if_zero3_context(self)(student_head.weight) - - with teacher_context, student_context: - # Compute liger fused JSD loss - loss = self.liger_jsd_loss( - student_input=student_hidden, - student_weight=student_head.weight, - teacher_input=teacher_hidden, - teacher_weight=teacher_head.weight, - true_labels=true_labels, - student_bias=getattr(student_head, 'bias', None), - teacher_bias=getattr(teacher_head, 'bias', None), - ) - # Release hidden states after loss computation - del student_hidden, teacher_hidden, true_labels - outputs_student = None - # Teacher API mode: top-k logprobs fetched from external teacher server - elif self.use_teacher_api: - assert teacher_api_logprobs is not None - if self.args.sft_alpha > 0: - model_inputs['labels'] = inputs['labels'] - outputs_student = model(**model_inputs) - - # teacher_api shape: [batch, seq_len-1, topk] - # Pad to [batch, seq_len, topk] so it aligns with student logits. - teacher_api_logprobs = F.pad(teacher_api_logprobs, (0, 0, 0, 1), value=float('-inf')) - teacher_api_indices = F.pad(teacher_api_indices, (0, 0, 0, 1), value=0) - logits_to_keep = inputs.get('logits_to_keep') - if logits_to_keep is not None: - if isinstance(logits_to_keep, torch.Tensor) and logits_to_keep.dtype == torch.bool: - teacher_api_logprobs = teacher_api_logprobs[:, logits_to_keep] - teacher_api_indices = teacher_api_indices[:, logits_to_keep] - else: - n = logits_to_keep.item() if isinstance(logits_to_keep, torch.Tensor) else int(logits_to_keep) - teacher_api_logprobs = teacher_api_logprobs[:, -n:] - teacher_api_indices = teacher_api_indices[:, -n:] - - opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None - teacher_out = TeacherOutput( - topk_logprobs=teacher_api_logprobs, - topk_indices=teacher_api_indices, - opsd_teacher_labels=opsd_labels, - ) - loss = self._compute_jsd_loss(outputs_student.logits, teacher_out, inputs['labels']) - - if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: - loss = loss + self.args.sft_alpha * outputs_student.loss - # Self-distillation mode: student model doubles as teacher - elif self._is_self_distillation: - if self.args.sft_alpha > 0: - model_inputs['labels'] = inputs['labels'] - outputs_student = model(**model_inputs) - - t_fwd = teacher_fwd_inputs if teacher_fwd_inputs is not None else { - k: v - for k, v in model_inputs.items() if k != 'labels' - } - - adapter_ctx = ( - self.accelerator.unwrap_model(model).disable_adapter() - if self._teacher_use_disable_adapter else nullcontext()) - with torch.no_grad(), adapter_ctx, \ - disable_gradient_checkpointing(model, self.args.gradient_checkpointing_kwargs): - outputs_teacher = model(**t_fwd) - - opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None - teacher_out = TeacherOutput(full_logits=outputs_teacher.logits, opsd_teacher_labels=opsd_labels) - loss = self._compute_jsd_loss(outputs_student.logits, teacher_out, inputs['labels']) - - if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: - loss = loss + self.args.sft_alpha * outputs_student.loss - # Separate teacher model provided + adapter_ctx = ( + self.accelerator.unwrap_model(model).disable_adapter() + if self._teacher_use_disable_adapter else nullcontext()) + with torch.no_grad(), adapter_ctx, \ + disable_gradient_checkpointing(model, self.args.gradient_checkpointing_kwargs): + outputs_teacher = model(**t_fwd) + + opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None + teacher_out = TeacherOutput(full_logits=outputs_teacher.logits, opsd_teacher_labels=opsd_labels) + loss = self._compute_jsd_loss(outputs_student.logits, teacher_out, inputs['labels']) + + if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: + loss = loss + self.args.sft_alpha * outputs_student.loss + # Separate teacher model provided + elif self.use_mopd: + if self.teacher_model_group is None: + # Use single teacher model + teacher_model_group = [self.teacher_model] else: + teacher_model_group = self.teacher_model_group + # 预先计算教师模型数量,避免除零错误和重复计算 + num_teacher_models = len(teacher_model_group) + if num_teacher_models == 0: + raise ValueError("teacher_model_group cannot be empty") + loss = torch.tensor(0.0, device=model.device) + for teacher_model in teacher_model_group: if not hasattr(self, 'student_tokenizer'): student_model = model student_model_path = self.get_model_path(student_model) @@ -510,7 +518,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_labels, teacher_attention_mask, teacher_prompt_length, - ) = self.build_teacher_inputs_from_texts( + ) = self.build_inputs_from_texts( self.teacher_tokenizer, prompt_texts, completion_texts @@ -520,7 +528,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N student_labels, student_attention_mask, student_prompt_length, - ) = self.build_teacher_inputs_from_texts( + ) = self.build_inputs_from_texts( self.student_tokenizer, prompt_texts, completion_texts @@ -543,7 +551,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # Release intermediate tensors to free memory del teacher_attention_mask, student_attention_mask - loss = self.gold_adapter( + loss_item = self.gold_adapter( student_logits=outputs_student.logits, teacher_logits=teacher_logits, student_labels=student_labels, @@ -552,91 +560,113 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_input_ids=teacher_input_ids, ) - loss_total += loss / len(self.teacher_model_group) + loss = loss + loss_item / num_teacher_models +# ... existing code ... + + # Separate teacher model provided + else: + assert self.teacher_model is not None + if self.args.sft_alpha > 0: + model_inputs['labels'] = inputs['labels'] + outputs_student = model(**model_inputs) + + t_fwd = teacher_fwd_inputs if teacher_fwd_inputs is not None else { + k: v + for k, v in model_inputs.items() if k != 'labels' + } + + load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() + with torch.no_grad(), load_context, disable_gradient_checkpointing(self.teacher_model, + self.args.gradient_checkpointing_kwargs): + outputs_teacher = self.teacher_model(**t_fwd) + + opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None + teacher_out = TeacherOutput(full_logits=outputs_teacher.logits, opsd_teacher_labels=opsd_labels) + loss = self._compute_jsd_loss(outputs_student.logits, teacher_out, inputs['labels']) + + if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: + loss = loss + self.args.sft_alpha * outputs_student.loss # Return loss if return_outputs: - if self.use_liger_gkd_loss: - # outputs has been released in liger loss computation to reduce peak memory - outputs_student = None - return (loss_total, outputs_student) + return (loss, outputs_student) else: - return loss_total - - def build_inputs_from_texts( - self, - tokenizer: PreTrainedTokenizerBase, - prompt_texts: list[str], - completion_texts: list[str], - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: - """Tokenize teacher prompts/completions and produce tensors ready for GOLD loss.""" - - pad_token_id = tokenizer.pad_token_id - eos_token_id = tokenizer.eos_token_id - - prompt_token_ids = tokenizer(prompt_texts, add_special_tokens=True)["input_ids"] - completion_token_ids = tokenizer(completion_texts, add_special_tokens=False)["input_ids"] - - sequences: list[torch.Tensor] = [] - attention_masks: list[torch.Tensor] = [] - labels_list: list[torch.Tensor] = [] - prompt_lengths: list[int] = [] - # Get device using reliable detection method - device = None - try: - # First try to get device from model parameters - if hasattr(self, 'model') and self.model is not None: - device = next(self.model.parameters()).device - elif hasattr(self, 'teacher_model') and self.teacher_model is not None: - device = next(self.teacher_model.parameters()).device - except (AttributeError, StopIteration): - pass - - # Fallback to default device detection - if device is None: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - for prompt_ids, completion_ids in zip(prompt_token_ids, completion_token_ids, strict=True): - # Remove trailing EOS from prompt so completions can extend cleanly - if eos_token_id is not None and prompt_ids and prompt_ids[-1] == eos_token_id: - prompt_ids = prompt_ids[:-1] - - prompt_lengths.append(len(prompt_ids)) - sequence = list(prompt_ids) - sequence.extend(completion_ids) - if eos_token_id is not None: - sequence.append(eos_token_id) - - seq_tensor = torch.tensor(sequence, dtype=torch.long, device=device) - sequences.append(seq_tensor) - attention_masks.append(torch.ones_like(seq_tensor)) - labels = seq_tensor.clone() - labels[: len(prompt_ids)] = -100 - if pad_token_id is not None: - labels[labels == pad_token_id] = -100 - labels_list.append(labels) - - teacher_input_ids = pad( - sequences, - padding_side="right", - padding_value=pad_token_id if pad_token_id is not None else 0, - ) - teacher_attention_mask = pad(attention_masks, padding_side="right", padding_value=0).bool() - teacher_labels = pad(labels_list, padding_side="right", padding_value=-100) - + return loss + + def build_inputs_from_texts( + self, + tokenizer: PreTrainedTokenizerBase, + prompt_texts: list[str], + completion_texts: list[str], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Tokenize teacher prompts/completions and produce tensors ready for GOLD loss.""" + + pad_token_id = tokenizer.pad_token_id + eos_token_id = tokenizer.eos_token_id + + prompt_token_ids = tokenizer(prompt_texts, add_special_tokens=True)["input_ids"] + completion_token_ids = tokenizer(completion_texts, add_special_tokens=False)["input_ids"] + + sequences: list[torch.Tensor] = [] + attention_masks: list[torch.Tensor] = [] + labels_list: list[torch.Tensor] = [] + prompt_lengths: list[int] = [] + # Get device using reliable detection method + device = None + try: + # First try to get device from model parameters + if hasattr(self, 'model') and self.model is not None: + device = next(self.model.parameters()).device + elif hasattr(self, 'teacher_model') and self.teacher_model is not None: + device = next(self.teacher_model.parameters()).device + except (AttributeError, StopIteration): + pass + + # Fallback to default device detection + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + for prompt_ids, completion_ids in zip(prompt_token_ids, completion_token_ids, strict=True): + # Remove trailing EOS from prompt so completions can extend cleanly + if eos_token_id is not None and prompt_ids and prompt_ids[-1] == eos_token_id: + prompt_ids = prompt_ids[:-1] + + prompt_lengths.append(len(prompt_ids)) + sequence = list(prompt_ids) + sequence.extend(completion_ids) if eos_token_id is not None: - for row in range(teacher_attention_mask.size(0)): - valid = ( - teacher_input_ids[row] != pad_token_id - if pad_token_id is not None - else teacher_attention_mask[row].bool() - ) - if valid.any(): - last_idx = valid.nonzero(as_tuple=True)[0][-1] - teacher_attention_mask[row, last_idx + 1:] = False + sequence.append(eos_token_id) + + seq_tensor = torch.tensor(sequence, dtype=torch.long, device=device) + sequences.append(seq_tensor) + attention_masks.append(torch.ones_like(seq_tensor)) + labels = seq_tensor.clone() + labels[: len(prompt_ids)] = -100 + if pad_token_id is not None: + labels[labels == pad_token_id] = -100 + labels_list.append(labels) + + teacher_input_ids = pad( + sequences, + padding_side="right", + padding_value=pad_token_id if pad_token_id is not None else 0, + ) + teacher_attention_mask = pad(attention_masks, padding_side="right", padding_value=0).bool() + teacher_labels = pad(labels_list, padding_side="right", padding_value=-100) + + if eos_token_id is not None: + for row in range(teacher_attention_mask.size(0)): + valid = ( + teacher_input_ids[row] != pad_token_id + if pad_token_id is not None + else teacher_attention_mask[row].bool() + ) + if valid.any(): + last_idx = valid.nonzero(as_tuple=True)[0][-1] + teacher_attention_mask[row, last_idx + 1:] = False - teacher_prompt_length = max(prompt_lengths) if prompt_lengths else 0 + teacher_prompt_length = max(prompt_lengths) if prompt_lengths else 0 - return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length + return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length def get_model_path(self, model): model_path = getattr(model, 'name_or_path', None) From 0bd9b5548691bc1a836d00dde4bacb73940292c7 Mon Sep 17 00:00:00 2001 From: Chuming Yao <1416004356qq@gmail.com> Date: Fri, 15 May 2026 09:49:21 +0800 Subject: [PATCH 06/13] lint fix --- swift/arguments/rlhf_args.py | 10 +++--- swift/pipelines/train/rlhf.py | 8 ++--- swift/rlhf_trainers/gkd_trainer.py | 40 +++++++++++++----------- swift/rlhf_trainers/gold_loss_adapter.py | 6 ++-- 4 files changed, 34 insertions(+), 30 deletions(-) diff --git a/swift/arguments/rlhf_args.py b/swift/arguments/rlhf_args.py index 476f3392dd..b8b5808886 100644 --- a/swift/arguments/rlhf_args.py +++ b/swift/arguments/rlhf_args.py @@ -66,7 +66,7 @@ class TeacherModelArguments: """ teacher_model: Optional[str] = None teacher_model_group: List[str] = field(default_factory=list) - #todo 还需要增加mopd_config + # todo 还需要增加mopd_config use_mopd: bool = False teacher_adapters: List[str] = field(default_factory=list) teacher_model_type: Optional[str] = field( @@ -76,15 +76,15 @@ class TeacherModelArguments: default=None, metadata={ 'help': - 'DeepSpeed configuration for teacher model. ' - 'Can be a path to a json file or one of: zero0, zero1, zero2, zero3, zero2_offload, zero3_offload' + 'DeepSpeed configuration for teacher model. ' + 'Can be a path to a json file or one of: zero0, zero1, zero2, zero3, zero2_offload, zero3_offload' }) teacher_model_server: Optional[str] = field( default=None, metadata={ 'help': - 'URL of the teacher model server (e.g., http://localhost:8000). ' - 'When set, teacher logprobs are fetched via API instead of loading a local model.' + 'URL of the teacher model server (e.g., http://localhost:8000). ' + 'When set, teacher logprobs are fetched via API instead of loading a local model.' }) diff --git a/swift/pipelines/train/rlhf.py b/swift/pipelines/train/rlhf.py index a5f2626a1f..6e70606238 100644 --- a/swift/pipelines/train/rlhf.py +++ b/swift/pipelines/train/rlhf.py @@ -142,7 +142,7 @@ def _prepare_model_tokenizer(self): # Use teacher_model_type and teacher_model_revision if available, otherwise infer model_type = getattr(args, 'teacher_model_type', None) model_revision = getattr(args, 'teacher_model_revision', None) - + result = self._prepare_single_model_for_teacher_group(teacher_model_path, model_type, model_revision) if result is not None: model, _ = result @@ -187,7 +187,7 @@ def _prepare_model_tokenizer(self): def _prepare_single_model_for_teacher_group(self, model_id_or_path, model_type, model_revision): """Prepare a single model for teacher_model_group.""" args = self.args - + if model_type is None: model_info, _ = get_model_info_meta(model_id_or_path) model_type = model_info.model_type @@ -200,7 +200,7 @@ def _prepare_single_model_for_teacher_group(self, model_id_or_path, model_type, hub_token=args.hub_token, ) task_type, num_labels = self._get_model_task_type(model_dir) - + context = nullcontext() if args.teacher_deepspeed: if args.teacher_deepspeed.get('zero_optimization', {}).get('stage') != 3: @@ -302,7 +302,7 @@ def _get_trainer_kwargs(self): trainer_kwargs['teacher_use_disable_adapter'] = getattr(self.args, '_teacher_use_disable_adapter', False) if self.args.use_mopd: trainer_kwargs['use_mopd'] = self.args.use_mopd - #todo + # todo # trainer_kwargs['mopd_config'] = self.args.mopd_config return trainer_kwargs diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 56bfe1adf5..50ae4569a4 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -2,22 +2,23 @@ import inspect import os import random -import torch -import torch.nn as nn -import torch.nn.functional as F -import trl -from accelerate.utils import gather_object, is_peft_model from collections import defaultdict, deque from contextlib import contextmanager, nullcontext from copy import deepcopy from dataclasses import dataclass from enum import Enum -from packaging import version -from transformers import PreTrainedModel -from trl import SFTTrainer as HFSFTTrainer from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import trl +from accelerate.utils import gather_object, is_peft_model +from packaging import version from transformers import AutoTokenizer +from transformers import PreTrainedModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from trl import SFTTrainer as HFSFTTrainer from trl.trainer.utils import pad from swift.template import TemplateInputs @@ -30,6 +31,7 @@ try: from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss + _liger_kernel_available = True except ImportError: _liger_kernel_available = False @@ -561,7 +563,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N ) loss = loss + loss_item / num_teacher_models -# ... existing code ... + # ... existing code ... # Separate teacher model provided else: @@ -979,16 +981,16 @@ def _align_vocab_size(student_logits, teacher_logits): return student_logits, teacher_logits def generalized_jsd_loss( - self, - student_logits, - teacher_logits=None, - labels=None, - beta=0.5, - temperature=1.0, - chunk_size=512, - topk=None, - teacher_topk_logprobs=None, - teacher_topk_indices=None, + self, + student_logits, + teacher_logits=None, + labels=None, + beta=0.5, + temperature=1.0, + chunk_size=512, + topk=None, + teacher_topk_logprobs=None, + teacher_topk_indices=None, ): # Align vocab sizes when student and teacher have different vocabulary dimensions if teacher_logits is not None: diff --git a/swift/rlhf_trainers/gold_loss_adapter.py b/swift/rlhf_trainers/gold_loss_adapter.py index c4398d0abe..e03372954b 100644 --- a/swift/rlhf_trainers/gold_loss_adapter.py +++ b/swift/rlhf_trainers/gold_loss_adapter.py @@ -1,7 +1,8 @@ +from typing import Optional, Tuple, List + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional, Tuple, List from transformers import PreTrainedTokenizerBase @@ -226,6 +227,7 @@ def decode_tokens(tokenizer, token_ids): pieces.append(cur[len(prev):]) prev = cur return pieces + student_token_ids = student_input_ids[i, s_start:s_start + s_size].tolist() teacher_token_ids = teacher_input_ids[i, t_start:t_start + t_size].tolist() @@ -528,4 +530,4 @@ def _compute_jsd_for_matched( if torch.isnan(jsd) or torch.isinf(jsd): return torch.tensor(0.0, device=student_probs.device, requires_grad=True) - return jsd \ No newline at end of file + return jsd From 5667586272d1113de86d62847395067ed2ea831d Mon Sep 17 00:00:00 2001 From: Chuming Yao <1416004356qq@gmail.com> Date: Mon, 18 May 2026 09:58:12 +0800 Subject: [PATCH 07/13] [MOPD] examples --- examples/train/rlhf/mopd/mopd.sh | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 examples/train/rlhf/mopd/mopd.sh diff --git a/examples/train/rlhf/mopd/mopd.sh b/examples/train/rlhf/mopd/mopd.sh new file mode 100644 index 0000000000..76eca2edc1 --- /dev/null +++ b/examples/train/rlhf/mopd/mopd.sh @@ -0,0 +1,31 @@ +NPROC_PER_NODE=7 \ +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 \ +swift rlhf \ + --rlhf_type gkd \ + --model Qwen/Qwen3-8B-Base \ + --teacher_model_group Qwen/Qwen3-32B,AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2 \ + --use_mopd true \ + --tuner_type full \ + --dataset open-thoughts/OpenThoughts3-1.2M#10000 \ + --seq_kd false \ + --lmbda 1 \ + --beta 1 \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-5 \ + --gradient_accumulation_steps 1 \ + --save_steps 1000 \ + --save_total_limit 2 \ + --logging_steps 1 \ + --max_length 16000 \ + --max_completion_length 8192 \ + --output_dir output \ + --warmup_ratio 0.05 \ + --save_only_model true \ + --dataloader_num_workers 64 \ + --dataset_num_proc 4 \ + --deepspeed zero3 \ + --teacher_deepspeed zero3 \ + From a227202489033fc6c5a1cebc879aec272d50be5b Mon Sep 17 00:00:00 2001 From: Chuming Yao <1416004356qq@gmail.com> Date: Mon, 18 May 2026 16:41:08 +0800 Subject: [PATCH 08/13] [MOPD] lint fix --- swift/rlhf_trainers/gkd_trainer.py | 6 ------ swift/rlhf_trainers/gold_loss_adapter.py | 9 --------- 2 files changed, 15 deletions(-) diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 3de66bd821..a48e475f67 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -3,11 +3,6 @@ import os import random import re -import torch -import torch.nn as nn -import torch.nn.functional as F -import trl -from accelerate.utils import gather_object, is_peft_model from collections import defaultdict, deque from contextlib import contextmanager, nullcontext from copy import deepcopy @@ -26,7 +21,6 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase from trl import SFTTrainer as HFSFTTrainer from trl.trainer.utils import RepeatSampler, pad -from typing import Dict, Optional, Union from swift.infer_engine.protocol import MultiModalRequestMixin from swift.template import TemplateInputs diff --git a/swift/rlhf_trainers/gold_loss_adapter.py b/swift/rlhf_trainers/gold_loss_adapter.py index e03372954b..5d124b7c2c 100644 --- a/swift/rlhf_trainers/gold_loss_adapter.py +++ b/swift/rlhf_trainers/gold_loss_adapter.py @@ -219,15 +219,6 @@ def _compute_distillation_loss( student_probs = F.softmax(student_ans_logits / self.student_temperature, dim=-1) teacher_probs = F.softmax(teacher_ans_logits / self.teacher_temperature, dim=-1) - def decode_tokens(tokenizer, token_ids): - pieces = [] - prev = "" - for k in range(len(token_ids)): - cur = tokenizer.decode(token_ids[:k + 1], skip_special_tokens=False) - pieces.append(cur[len(prev):]) - prev = cur - return pieces - student_token_ids = student_input_ids[i, s_start:s_start + s_size].tolist() teacher_token_ids = teacher_input_ids[i, t_start:t_start + t_size].tolist() From 4f75226473d76a8b879652865cf644dff0eb018d Mon Sep 17 00:00:00 2001 From: Chuming Yao <1416004356qq@gmail.com> Date: Mon, 18 May 2026 17:01:04 +0800 Subject: [PATCH 09/13] [MOPD] lint fix --- swift/rlhf_trainers/gkd_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index a48e475f67..32cdc557c9 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -543,7 +543,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N completion_texts = inputs['completion_texts'] # Add teacher model memory management like in liger branch - load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() + load_context = self.load_teacher_model_context() if self.args.offload_teacher_model \ + else nullcontext() with load_context: with torch.no_grad(), disable_gradient_checkpointing(teacher_model, self.args.gradient_checkpointing_kwargs): From e1e58af2998cc1913e7de4367ecaa361259e8bdb Mon Sep 17 00:00:00 2001 From: Chuming Yao <1416004356qq@gmail.com> Date: Mon, 18 May 2026 17:30:17 +0800 Subject: [PATCH 10/13] [MOPD] isort fix --- swift/rlhf_trainers/gkd_trainer.py | 4 ++-- swift/rlhf_trainers/gold_loss_adapter.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 32cdc557c9..55eb524905 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -16,8 +16,7 @@ import trl from accelerate.utils import gather_object, is_peft_model from packaging import version -from transformers import AutoTokenizer -from transformers import PreTrainedModel +from transformers import AutoTokenizer, PreTrainedModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase from trl import SFTTrainer as HFSFTTrainer from trl.trainer.utils import RepeatSampler, pad @@ -524,6 +523,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_path) from .gold_loss_adapter import GOLDLossAdapter + # Initialize adapter only once if not hasattr(self, 'gold_adapter'): self.gold_adapter = GOLDLossAdapter( diff --git a/swift/rlhf_trainers/gold_loss_adapter.py b/swift/rlhf_trainers/gold_loss_adapter.py index 5d124b7c2c..55d183cd82 100644 --- a/swift/rlhf_trainers/gold_loss_adapter.py +++ b/swift/rlhf_trainers/gold_loss_adapter.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple import torch import torch.nn as nn From b53a0f3ca4fad84bb86f38250bd2706848cc95d8 Mon Sep 17 00:00:00 2001 From: Chuming Yao <1416004356qq@gmail.com> Date: Tue, 19 May 2026 10:51:07 +0800 Subject: [PATCH 11/13] [MOPD] Loop optimization --- examples/train/rlhf/mopd/mopd.sh | 5 +- swift/rlhf_trainers/gkd_trainer.py | 180 ++++++++++++++--------------- 2 files changed, 91 insertions(+), 94 deletions(-) diff --git a/examples/train/rlhf/mopd/mopd.sh b/examples/train/rlhf/mopd/mopd.sh index 76eca2edc1..8013bf6b48 100644 --- a/examples/train/rlhf/mopd/mopd.sh +++ b/examples/train/rlhf/mopd/mopd.sh @@ -4,7 +4,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 \ swift rlhf \ --rlhf_type gkd \ --model Qwen/Qwen3-8B-Base \ - --teacher_model_group Qwen/Qwen3-32B,AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2 \ + --teacher_model_group Qwen/Qwen3-32B AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2 \ --use_mopd true \ --tuner_type full \ --dataset open-thoughts/OpenThoughts3-1.2M#10000 \ @@ -27,5 +27,4 @@ swift rlhf \ --dataloader_num_workers 64 \ --dataset_num_proc 4 \ --deepspeed zero3 \ - --teacher_deepspeed zero3 \ - + --teacher_deepspeed zero3 diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 55eb524905..b6bed8c102 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -139,7 +139,7 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non else: self.teacher_model = None # Initialize teacher model group (for MOPD) - if self.teacher_model_group is not None and len(self.teacher_model_group) > 0: + if self.use_mopd and self.teacher_model_group is not None and len(self.teacher_model_group) > 0: prepared_models = [] for model_name in self.teacher_model_group: if self.is_deepspeed_enabled: @@ -158,6 +158,38 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self.offload_model(self.accelerator.unwrap_model(prepared_model)) prepared_models.append(prepared_model) self.teacher_model_group = prepared_models + from .gold_loss_adapter import GOLDLossAdapter + # Initialize student tokenizer (only once) + if not hasattr(self, 'student_tokenizer'): + student_model_path = self.get_model_path(model) + self.student_tokenizer = AutoTokenizer.from_pretrained(student_model_path) + # Initialize teacher tokenizer (only once) + if not hasattr(self, 'teacher_tokenizer_group'): + self.teacher_tokenizer_group = {} + for teacher_model in self.teacher_model_group: + adapter_key = id(teacher_model) + if adapter_key not in self.teacher_tokenizer_group: + teacher_model_path = self.get_model_path(teacher_model) + self.teacher_tokenizer_group[adapter_key] = AutoTokenizer.from_pretrained(teacher_model_path) + # Initialize adapter only once + if not hasattr(self, 'gold_adapter_group'): + self.gold_adapter_group = {} + for teacher_tokenizer in self.teacher_tokenizer_group.values(): + adapter_key = id(teacher_tokenizer) + if adapter_key not in self.gold_adapter_group: + self.gold_adapter_group[adapter_key] = GOLDLossAdapter( + config={ + 'use_uld_loss': True, + 'use_extended_uld': True, + 'uld_use_hybrid_loss': True, + 'uld_crossentropy_weight': 0.0, + 'uld_distillation_weight': 1.0, + 'uld_student_temperature': 1.0, + 'uld_teacher_temperature': 1.0, + }, + student_tokenizer=self.student_tokenizer, + teacher_tokenizer=teacher_tokenizer, + ) else: self.teacher_model_group = None @@ -501,103 +533,69 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss = loss + self.args.sft_alpha * outputs_student.loss # Separate teacher model provided elif self.use_mopd: - if self.teacher_model_group is None: - # Use single teacher model - teacher_model_group = [self.teacher_model] - else: - teacher_model_group = self.teacher_model_group - # 预先计算教师模型数量,避免除零错误和重复计算 - num_teacher_models = len(teacher_model_group) + num_teacher_models = len(self.teacher_model_group) if num_teacher_models == 0: raise ValueError("teacher_model_group cannot be empty") loss = torch.tensor(0.0, device=model.device) + prompt_texts = inputs['prompt_text'] + completion_texts = inputs['completion_texts'] + ( + student_input_ids, + student_labels, + student_attention_mask, + student_prompt_length, + ) = self.build_inputs_from_texts( + self.student_tokenizer, + prompt_texts, + completion_texts + ) + # Student model forward pass (WITH gradients for student parameters) + outputs_student = model( + input_ids=student_input_ids, + attention_mask=student_attention_mask, + ) for teacher_model in teacher_model_group: - if not hasattr(self, 'student_tokenizer'): - student_model = model - student_model_path = self.get_model_path(student_model) - self.student_tokenizer = AutoTokenizer.from_pretrained(student_model_path) - # Initialize teacher tokenizer (only once) - if not hasattr(self, 'teacher_tokenizer'): - teacher_model_path = self.get_model_path(teacher_model) - - self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_path) - - from .gold_loss_adapter import GOLDLossAdapter + print('-------self.use_generalized_jsd_loss') + teacher_tokenizer = self.teacher_tokenizer_group[id(teacher_model)] + # Add teacher model memory management like in liger branch + load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() - # Initialize adapter only once - if not hasattr(self, 'gold_adapter'): - self.gold_adapter = GOLDLossAdapter( - config={ - 'use_uld_loss': True, - 'use_extended_uld': True, - 'uld_use_hybrid_loss': True, - 'uld_crossentropy_weight': 0.0, - 'uld_distillation_weight': 1.0, - 'uld_student_temperature': 1.0, - 'uld_teacher_temperature': 1.0, - }, - student_tokenizer=self.student_tokenizer, - teacher_tokenizer=self.teacher_tokenizer, + # Get adapter for current teacher tokenizer + adapter_key = id(teacher_tokenizer) + gold_adapter = self.gold_adapter_group[adapter_key] + with load_context: + with torch.no_grad(), disable_gradient_checkpointing(teacher_model, + self.args.gradient_checkpointing_kwargs): + ( + teacher_input_ids, + teacher_labels, + teacher_attention_mask, + teacher_prompt_length, + ) = self.build_inputs_from_texts( + teacher_tokenizer, + prompt_texts, + completion_texts ) - prompt_texts = inputs['prompt_text'] - completion_texts = inputs['completion_texts'] - - # Add teacher model memory management like in liger branch - load_context = self.load_teacher_model_context() if self.args.offload_teacher_model \ - else nullcontext() - with load_context: - with torch.no_grad(), disable_gradient_checkpointing(teacher_model, - self.args.gradient_checkpointing_kwargs): - ( - teacher_input_ids, - teacher_labels, - teacher_attention_mask, - teacher_prompt_length, - ) = self.build_inputs_from_texts( - self.teacher_tokenizer, - prompt_texts, - completion_texts - ) - ( - student_input_ids, - student_labels, - student_attention_mask, - student_prompt_length, - ) = self.build_inputs_from_texts( - self.student_tokenizer, - prompt_texts, - completion_texts - ) - - # Teacher model forward pass (NO gradients) - outputs_teacher = teacher_model( - input_ids=teacher_input_ids, - attention_mask=teacher_attention_mask, - ) - # Student model forward pass (WITH gradients for student parameters) - outputs_student = model( - input_ids=student_input_ids, - attention_mask=student_attention_mask, - ) - - # Ensure teacher_logits has gradient info but teacher model params don't participate - teacher_logits = outputs_teacher.logits.detach().requires_grad_(True) - - # Release intermediate tensors to free memory - del teacher_attention_mask, student_attention_mask - - loss_item = self.gold_adapter( - student_logits=outputs_student.logits, - teacher_logits=teacher_logits, - student_labels=student_labels, - teacher_labels=teacher_labels, - student_input_ids=student_input_ids, - teacher_input_ids=teacher_input_ids, - ) - - loss = loss + loss_item / num_teacher_models - # ... existing code ... + # Teacher model forward pass (NO gradients) + outputs_teacher = teacher_model( + input_ids=teacher_input_ids, + attention_mask=teacher_attention_mask, + ) + # Ensure teacher_logits has gradient info but teacher model params don't participate + teacher_logits = outputs_teacher.logits.detach().requires_grad_(True) + + # Release intermediate tensors to free memory + del teacher_attention_mask + loss_total = gold_adapter( + student_logits=outputs_student.logits, + teacher_logits=teacher_logits, + student_labels=student_labels, + teacher_labels=teacher_labels, + student_input_ids=student_input_ids, + teacher_input_ids=teacher_input_ids, + ) + loss += loss_total / len(teacher_model_group) # Separate teacher model provided else: assert self.teacher_model is not None From 1d2cd01ec0e18744754d2740cbe4a3ca00809aec Mon Sep 17 00:00:00 2001 From: Chuming Yao <1416004356qq@gmail.com> Date: Tue, 19 May 2026 15:42:30 +0800 Subject: [PATCH 12/13] =?UTF-8?q?[MOPD]=20=E6=BA=90=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E5=A3=B0=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- swift/rlhf_trainers/gkd_trainer.py | 7 ++++--- swift/rlhf_trainers/gold_loss_adapter.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index b6bed8c102..947a93989d 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -554,8 +554,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N input_ids=student_input_ids, attention_mask=student_attention_mask, ) - for teacher_model in teacher_model_group: - print('-------self.use_generalized_jsd_loss') + for teacher_model in self.teacher_model_group: teacher_tokenizer = self.teacher_tokenizer_group[id(teacher_model)] # Add teacher model memory management like in liger branch load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() @@ -587,6 +586,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # Release intermediate tensors to free memory del teacher_attention_mask + + # trl/experimental/gold/gold_trainer.py loss_total = gold_adapter( student_logits=outputs_student.logits, teacher_logits=teacher_logits, @@ -595,7 +596,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N student_input_ids=student_input_ids, teacher_input_ids=teacher_input_ids, ) - loss += loss_total / len(teacher_model_group) + loss += loss_total / len(self.teacher_model_group) # Separate teacher model provided else: assert self.teacher_model is not None diff --git a/swift/rlhf_trainers/gold_loss_adapter.py b/swift/rlhf_trainers/gold_loss_adapter.py index 55d183cd82..e47e7dddd6 100644 --- a/swift/rlhf_trainers/gold_loss_adapter.py +++ b/swift/rlhf_trainers/gold_loss_adapter.py @@ -1,3 +1,16 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import List, Optional, Tuple import torch From 479ddf77378e0174e7975f693b7787df9c5a9ae9 Mon Sep 17 00:00:00 2001 From: Chuming Yao <1416004356qq@gmail.com> Date: Tue, 19 May 2026 16:57:58 +0800 Subject: [PATCH 13/13] [MOPD] pre-commit --- swift/rlhf_trainers/gkd_trainer.py | 31 ++++------- swift/rlhf_trainers/gold_loss_adapter.py | 70 ++++++++---------------- 2 files changed, 34 insertions(+), 67 deletions(-) diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 947a93989d..ee52f56575 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -159,6 +159,7 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non prepared_models.append(prepared_model) self.teacher_model_group = prepared_models from .gold_loss_adapter import GOLDLossAdapter + # Initialize student tokenizer (only once) if not hasattr(self, 'student_tokenizer'): student_model_path = self.get_model_path(model) @@ -535,7 +536,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N elif self.use_mopd: num_teacher_models = len(self.teacher_model_group) if num_teacher_models == 0: - raise ValueError("teacher_model_group cannot be empty") + raise ValueError('teacher_model_group cannot be empty') loss = torch.tensor(0.0, device=model.device) prompt_texts = inputs['prompt_text'] completion_texts = inputs['completion_texts'] @@ -544,11 +545,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N student_labels, student_attention_mask, student_prompt_length, - ) = self.build_inputs_from_texts( - self.student_tokenizer, - prompt_texts, - completion_texts - ) + ) = self.build_inputs_from_texts(self.student_tokenizer, prompt_texts, completion_texts) # Student model forward pass (WITH gradients for student parameters) outputs_student = model( input_ids=student_input_ids, @@ -570,11 +567,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_labels, teacher_attention_mask, teacher_prompt_length, - ) = self.build_inputs_from_texts( - teacher_tokenizer, - prompt_texts, - completion_texts - ) + ) = self.build_inputs_from_texts(teacher_tokenizer, prompt_texts, completion_texts) # Teacher model forward pass (NO gradients) outputs_teacher = teacher_model( @@ -638,8 +631,8 @@ def build_inputs_from_texts( pad_token_id = tokenizer.pad_token_id eos_token_id = tokenizer.eos_token_id - prompt_token_ids = tokenizer(prompt_texts, add_special_tokens=True)["input_ids"] - completion_token_ids = tokenizer(completion_texts, add_special_tokens=False)["input_ids"] + prompt_token_ids = tokenizer(prompt_texts, add_special_tokens=True)['input_ids'] + completion_token_ids = tokenizer(completion_texts, add_special_tokens=False)['input_ids'] sequences: list[torch.Tensor] = [] attention_masks: list[torch.Tensor] = [] @@ -674,26 +667,24 @@ def build_inputs_from_texts( sequences.append(seq_tensor) attention_masks.append(torch.ones_like(seq_tensor)) labels = seq_tensor.clone() - labels[: len(prompt_ids)] = -100 + labels[:len(prompt_ids)] = -100 if pad_token_id is not None: labels[labels == pad_token_id] = -100 labels_list.append(labels) teacher_input_ids = pad( sequences, - padding_side="right", + padding_side='right', padding_value=pad_token_id if pad_token_id is not None else 0, ) - teacher_attention_mask = pad(attention_masks, padding_side="right", padding_value=0).bool() - teacher_labels = pad(labels_list, padding_side="right", padding_value=-100) + teacher_attention_mask = pad(attention_masks, padding_side='right', padding_value=0).bool() + teacher_labels = pad(labels_list, padding_side='right', padding_value=-100) if eos_token_id is not None: for row in range(teacher_attention_mask.size(0)): valid = ( teacher_input_ids[row] != pad_token_id - if pad_token_id is not None - else teacher_attention_mask[row].bool() - ) + if pad_token_id is not None else teacher_attention_mask[row].bool()) if valid.any(): last_idx = valid.nonzero(as_tuple=True)[0][-1] teacher_attention_mask[row, last_idx + 1:] = False diff --git a/swift/rlhf_trainers/gold_loss_adapter.py b/swift/rlhf_trainers/gold_loss_adapter.py index e47e7dddd6..c7948ca984 100644 --- a/swift/rlhf_trainers/gold_loss_adapter.py +++ b/swift/rlhf_trainers/gold_loss_adapter.py @@ -120,9 +120,7 @@ def _initialize_vocabulary_mapping(self): if self._vocab_mapping: max_matched_teacher_id = max(self._vocab_mapping.keys()) - self.mapping_tensor = torch.full( - (max_matched_teacher_id + 1,), -1, dtype=torch.long - ) + self.mapping_tensor = torch.full((max_matched_teacher_id + 1,), -1, dtype=torch.long) for k, v in self._vocab_mapping.items(): self.mapping_tensor[k] = v if self.device: @@ -159,18 +157,11 @@ def forward( crossentropy_loss = self._compute_cross_entropy(student_logits, student_labels) # 2. Distillation loss (ULD) - distillation_loss = self._compute_distillation_loss( - student_logits, teacher_logits, - student_labels, teacher_labels, - student_input_ids, teacher_input_ids - ) + distillation_loss = self._compute_distillation_loss(student_logits, teacher_logits, student_labels, + teacher_labels, student_input_ids, teacher_input_ids) return crossentropy_loss + distillation_loss - def _compute_cross_entropy( - self, - student_logits: torch.Tensor, - student_labels: torch.Tensor - ) -> torch.Tensor: + def _compute_cross_entropy(self, student_logits: torch.Tensor, student_labels: torch.Tensor) -> torch.Tensor: """计算cross-entropy loss""" if self.crossentropy_weight <= 0: return torch.tensor(0.0, device=student_logits.device, requires_grad=True) @@ -179,10 +170,7 @@ def _compute_cross_entropy( shift_labels = student_labels[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) - ce_loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1) - ) + ce_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return self.crossentropy_weight * ce_loss def _compute_distillation_loss( @@ -238,15 +226,12 @@ def _compute_distillation_loss( # Token对齐 if self.use_extended_uld: student_groups, teacher_groups = self._build_alignment_groups_from_ids( - student_token_ids, teacher_token_ids - ) + student_token_ids, teacher_token_ids) - student_aligned = self._merge_probabilities_with_groups( - student_probs, student_groups, student_token_ids - ) - teacher_aligned = self._merge_probabilities_with_groups( - teacher_probs, teacher_groups, teacher_token_ids - ) + student_aligned = self._merge_probabilities_with_groups(student_probs, student_groups, + student_token_ids) + teacher_aligned = self._merge_probabilities_with_groups(teacher_probs, teacher_groups, + teacher_token_ids) else: min_len = min(len(student_token_ids), len(teacher_token_ids)) @@ -281,11 +266,8 @@ def _get_answer_regions(self, labels: torch.Tensor) -> Tuple[List[int], List[int return indices, sizes - def _build_alignment_groups_from_ids( - self, - student_token_ids: List[int], - teacher_token_ids: List[int] - ) -> Tuple[List[List[int]], List[List[int]]]: + def _build_alignment_groups_from_ids(self, student_token_ids: List[int], + teacher_token_ids: List[int]) -> Tuple[List[List[int]], List[List[int]]]: """ 基于文本内容构建对齐组 使用贪心子串匹配算法 @@ -293,7 +275,7 @@ def _build_alignment_groups_from_ids( def decode_tokens(tokenizer, token_ids): pieces = [] - prev = "" + prev = '' for k in range(len(token_ids)): cur = tokenizer.decode(token_ids[:k + 1], skip_special_tokens=False) pieces.append(cur[len(prev):]) @@ -310,8 +292,8 @@ def decode_tokens(tokenizer, token_ids): t_idx = 0 while s_idx < len(student_pieces) and t_idx < len(teacher_pieces): - student_text = "" - teacher_text = "" + student_text = '' + teacher_text = '' student_group = [] teacher_group = [] @@ -408,7 +390,7 @@ def _compute_basic_uld_loss( if t_vocab < max_vocab: teacher_sorted = F.pad(teacher_sorted, (0, max_vocab - t_vocab)) - loss = F.l1_loss(student_sorted, teacher_sorted, reduction="sum") + loss = F.l1_loss(student_sorted, teacher_sorted, reduction='sum') loss /= student_aligned.size(0) return loss @@ -425,9 +407,7 @@ def _compute_hybrid_uld_loss( # 创建matched/unmatched masks if self._teacher_matched_ids: - teacher_matched_idx = torch.tensor( - sorted(self._teacher_matched_ids), dtype=torch.long, device=device - ) + teacher_matched_idx = torch.tensor(sorted(self._teacher_matched_ids), dtype=torch.long, device=device) student_matched_idx = self.mapping_tensor[teacher_matched_idx] else: teacher_matched_idx = torch.tensor([], dtype=torch.long, device=device) @@ -449,9 +429,7 @@ def _compute_hybrid_uld_loss( student_matched_probs = student_aligned[:, student_matched_idx] matched_count = teacher_matched_probs.size(-1) - matched_loss = self._compute_jsd_for_matched( - student_matched_probs, teacher_matched_probs - ) + matched_loss = self._compute_jsd_for_matched(student_matched_probs, teacher_matched_probs) # 2. Unmatched tokens的排序L1损失 teacher_unmatched = teacher_aligned[:, ~teacher_matched_mask] student_unmatched = student_aligned[:, ~student_matched_mask] @@ -470,7 +448,7 @@ def _compute_hybrid_uld_loss( if s_size < max_size: student_sorted = F.pad(student_sorted, (0, max_size - s_size)) - unmatched_loss = F.l1_loss(student_sorted, teacher_sorted, reduction="sum") + unmatched_loss = F.l1_loss(student_sorted, teacher_sorted, reduction='sum') unmatched_loss /= student_aligned.size(0) # 3. 加权组合 @@ -489,12 +467,10 @@ def _compute_hybrid_uld_loss( return total_loss - def _compute_jsd_for_matched( - self, - student_probs: torch.Tensor, - teacher_probs: torch.Tensor, - epsilon: float = 1e-8 - ) -> torch.Tensor: + def _compute_jsd_for_matched(self, + student_probs: torch.Tensor, + teacher_probs: torch.Tensor, + epsilon: float = 1e-8) -> torch.Tensor: """计算matched tokens的JSD损失,添加数值稳定性处理""" batch_seq_len, num_matched = student_probs.shape