diff --git a/examples/train/rlhf/mopd/mopd.sh b/examples/train/rlhf/mopd/mopd.sh new file mode 100644 index 0000000000..8013bf6b48 --- /dev/null +++ b/examples/train/rlhf/mopd/mopd.sh @@ -0,0 +1,30 @@ +NPROC_PER_NODE=7 \ +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 \ +swift rlhf \ + --rlhf_type gkd \ + --model Qwen/Qwen3-8B-Base \ + --teacher_model_group Qwen/Qwen3-32B AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2 \ + --use_mopd true \ + --tuner_type full \ + --dataset open-thoughts/OpenThoughts3-1.2M#10000 \ + --seq_kd false \ + --lmbda 1 \ + --beta 1 \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-5 \ + --gradient_accumulation_steps 1 \ + --save_steps 1000 \ + --save_total_limit 2 \ + --logging_steps 1 \ + --max_length 16000 \ + --max_completion_length 8192 \ + --output_dir output \ + --warmup_ratio 0.05 \ + --save_only_model true \ + --dataloader_num_workers 64 \ + --dataset_num_proc 4 \ + --deepspeed zero3 \ + --teacher_deepspeed zero3 diff --git a/swift/arguments/rlhf_args.py b/swift/arguments/rlhf_args.py index aa88fca008..0fb60b9a14 100644 --- a/swift/arguments/rlhf_args.py +++ b/swift/arguments/rlhf_args.py @@ -65,6 +65,9 @@ class TeacherModelArguments: remotely. When this is set, `teacher_model` is not required. Defaults to None. """ teacher_model: Optional[str] = None + teacher_model_group: List[str] = field(default_factory=list) + # todo 还需要增加mopd_config + use_mopd: bool = False teacher_adapters: List[str] = field(default_factory=list) teacher_model_type: Optional[str] = field( default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'}) @@ -73,15 +76,15 @@ class TeacherModelArguments: default=None, metadata={ 'help': - 'DeepSpeed configuration for teacher model. ' - 'Can be a path to a json file or one of: zero0, zero1, zero2, zero3, zero2_offload, zero3_offload' + 'DeepSpeed configuration for teacher model. ' + 'Can be a path to a json file or one of: zero0, zero1, zero2, zero3, zero2_offload, zero3_offload' }) teacher_model_server: Optional[str] = field( default=None, metadata={ 'help': - 'URL of the teacher model server (e.g., http://localhost:8000). ' - 'When set, teacher logprobs are fetched via API instead of loading a local model.' + 'URL of the teacher model server (e.g., http://localhost:8000). ' + 'When set, teacher logprobs are fetched via API instead of loading a local model.' }) diff --git a/swift/pipelines/train/rlhf.py b/swift/pipelines/train/rlhf.py index fd9c16b2ea..6e70606238 100644 --- a/swift/pipelines/train/rlhf.py +++ b/swift/pipelines/train/rlhf.py @@ -132,6 +132,24 @@ def _prepare_model_tokenizer(self): model, _ = result setattr(self, f'{key}_model', model) + # Handle teacher_model_group for GKD + self.teacher_model_group_models = None + if args.rlhf_type == 'gkd' and hasattr(args, 'teacher_model_group') and args.teacher_model_group: + logger.info(f'Loading teacher_model_group with {len(args.teacher_model_group)} models') + self.teacher_model_group_models = [] + for idx, teacher_model_path in enumerate(args.teacher_model_group): + logger.info(f'Loading teacher model group [{idx}]: {teacher_model_path}') + # Use teacher_model_type and teacher_model_revision if available, otherwise infer + model_type = getattr(args, 'teacher_model_type', None) + model_revision = getattr(args, 'teacher_model_revision', None) + + result = self._prepare_single_model_for_teacher_group(teacher_model_path, model_type, model_revision) + if result is not None: + model, _ = result + self.teacher_model_group_models.append(model) + logger.info(f'Successfully loaded teacher model group [{idx}]: {model}') + logger.info(f'Total teacher_model_group_models loaded: {len(self.teacher_model_group_models)}') + # Handle reward model(s) self.reward_model = None if hasattr(args, 'reward_model') and args.reward_model is not None: @@ -166,6 +184,44 @@ def _prepare_model_tokenizer(self): super()._prepare_model_tokenizer() + def _prepare_single_model_for_teacher_group(self, model_id_or_path, model_type, model_revision): + """Prepare a single model for teacher_model_group.""" + args = self.args + + if model_type is None: + model_info, _ = get_model_info_meta(model_id_or_path) + model_type = model_info.model_type + + model_dir = safe_snapshot_download( + model_id_or_path=model_id_or_path, + revision=model_revision, + download_model=False, + use_hf=args.use_hf, + hub_token=args.hub_token, + ) + task_type, num_labels = self._get_model_task_type(model_dir) + + context = nullcontext() + if args.teacher_deepspeed: + if args.teacher_deepspeed.get('zero_optimization', {}).get('stage') != 3: + context = disable_deepspeed_zero3() + with context: + model, processor = args.get_model_processor( + model=model_id_or_path, + model_type=model_type, + revision=model_revision, + task_type=task_type, + num_labels=num_labels) + + # For teacher models, set to eval mode and disable gradients + if self.args.sequence_parallel_size > 1: + sequence_parallel.prepare( + self.args.sequence_parallel_size, model, processor, padding_free=args.padding_free) + model.requires_grad_(False).eval() + + HfConfigFactory.set_config_attr(model.config, 'use_cache', False) + return model, processor + @classmethod def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_type=None): model = super().prepare_model(args, model, template=template, train_dataset=train_dataset, task_type=task_type) @@ -238,7 +294,16 @@ def _get_trainer_kwargs(self): trainer_kwargs['gkd_logits_topk'] = self.args.gkd_logits_topk if self.args.teacher_model_server: trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server + # Pass pre-loaded teacher_model_group_models if available, otherwise pass the string list + if hasattr(self, 'teacher_model_group_models') and self.teacher_model_group_models: + trainer_kwargs['teacher_model_group_models'] = self.teacher_model_group_models + else: + trainer_kwargs['teacher_model_group'] = self.args.teacher_model_group trainer_kwargs['teacher_use_disable_adapter'] = getattr(self.args, '_teacher_use_disable_adapter', False) + if self.args.use_mopd: + trainer_kwargs['use_mopd'] = self.args.use_mopd + # todo + # trainer_kwargs['mopd_config'] = self.args.mopd_config return trainer_kwargs diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 8256c0d5cf..ee52f56575 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -3,21 +3,23 @@ import os import random import re -import torch -import torch.nn as nn -import torch.nn.functional as F -import trl -from accelerate.utils import gather_object, is_peft_model from collections import defaultdict, deque from contextlib import contextmanager, nullcontext from copy import deepcopy from dataclasses import dataclass from enum import Enum +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import trl +from accelerate.utils import gather_object, is_peft_model from packaging import version -from transformers import PreTrainedModel +from transformers import AutoTokenizer, PreTrainedModel +from transformers.tokenization_utils_base import PreTrainedTokenizerBase from trl import SFTTrainer as HFSFTTrainer -from trl.trainer.utils import RepeatSampler -from typing import Dict, Optional, Union +from trl.trainer.utils import RepeatSampler, pad from swift.infer_engine.protocol import MultiModalRequestMixin from swift.template import TemplateInputs @@ -30,6 +32,7 @@ try: from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss + _liger_kernel_available = True except ImportError: _liger_kernel_available = False @@ -82,7 +85,9 @@ class GKDTrainer(RolloutTrainerMixin, SwiftMixin, HFGKDTrainer): def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs): teacher_model = kwargs.pop('teacher_model', None) + self.teacher_model_group = kwargs.pop('teacher_model_group', None) teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None) + self.use_mopd = kwargs.pop('use_mopd', False) self.vllm_client = kwargs.pop('vllm_client', None) self.gkd_logits_topk = kwargs.pop('gkd_logits_topk', None) teacher_model_server = kwargs.pop('teacher_model_server', None) @@ -133,6 +138,61 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self.offload_model(self.accelerator.unwrap_model(self.teacher_model)) else: self.teacher_model = None + # Initialize teacher model group (for MOPD) + if self.use_mopd and self.teacher_model_group is not None and len(self.teacher_model_group) > 0: + prepared_models = [] + for model_name in self.teacher_model_group: + if self.is_deepspeed_enabled: + if teacher_deepspeed_config is not None: + prepared_model = prepare_deepspeed( + model_name, self.accelerator, deepspeed_config=teacher_deepspeed_config, training_args=args) + else: + prepared_model = prepare_deepspeed(model_name, self.accelerator) + elif self.is_fsdp_enabled: + from .utils import prepare_fsdp + prepared_model = prepare_fsdp(model_name, self.accelerator) + else: + prepared_model = self.accelerator.prepare_model(model_name, evaluation_mode=True) + prepared_model.eval() + if self.args.offload_teacher_model: + self.offload_model(self.accelerator.unwrap_model(prepared_model)) + prepared_models.append(prepared_model) + self.teacher_model_group = prepared_models + from .gold_loss_adapter import GOLDLossAdapter + + # Initialize student tokenizer (only once) + if not hasattr(self, 'student_tokenizer'): + student_model_path = self.get_model_path(model) + self.student_tokenizer = AutoTokenizer.from_pretrained(student_model_path) + # Initialize teacher tokenizer (only once) + if not hasattr(self, 'teacher_tokenizer_group'): + self.teacher_tokenizer_group = {} + for teacher_model in self.teacher_model_group: + adapter_key = id(teacher_model) + if adapter_key not in self.teacher_tokenizer_group: + teacher_model_path = self.get_model_path(teacher_model) + self.teacher_tokenizer_group[adapter_key] = AutoTokenizer.from_pretrained(teacher_model_path) + # Initialize adapter only once + if not hasattr(self, 'gold_adapter_group'): + self.gold_adapter_group = {} + for teacher_tokenizer in self.teacher_tokenizer_group.values(): + adapter_key = id(teacher_tokenizer) + if adapter_key not in self.gold_adapter_group: + self.gold_adapter_group[adapter_key] = GOLDLossAdapter( + config={ + 'use_uld_loss': True, + 'use_extended_uld': True, + 'uld_use_hybrid_loss': True, + 'uld_crossentropy_weight': 0.0, + 'uld_distillation_weight': 1.0, + 'uld_student_temperature': 1.0, + 'uld_teacher_temperature': 1.0, + }, + student_tokenizer=self.student_tokenizer, + teacher_tokenizer=teacher_tokenizer, + ) + else: + self.teacher_model_group = None # Initialize rollout infrastructure for vLLM support self.prepare_rollout() @@ -169,7 +229,7 @@ def get_train_dataloader(self): def _build_opsd_teacher_data(self, inputs): """Build teacher data for OPSD by replacing the last user message with teacher_prompt. - Returns None if teacher_prompt is not available in all examples. + Returns None if teacher_prompt is not av ailable in all examples. """ if not all('teacher_prompt' in data and data['teacher_prompt'] for data in inputs): return None @@ -299,7 +359,33 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token new_position_ids = new_attention_mask.cumsum(dim=1) - 1 new_position_ids[new_position_ids < 0] = 0 inputs['position_ids'] = new_position_ids - return generated_tokens, new_attention_mask, new_labels + if self.use_mopd: + # 返回解码后的文本encoded_inputs + batch_size = generated_tokens.shape[0] + prompt_length = prompt_input_ids.shape[1] + + completion_texts = [] + pad_token_id = self.processing_class.pad_token_id + eos_token_id = self.processing_class.eos_token_id + for i in range(batch_size): + # Decode completion + completion_ids = generated_tokens[i][prompt_length:].tolist() + # 截断到第一个 EOS 或 PAD token + cleaned_ids = [] + for token_id in completion_ids: + if token_id == eos_token_id or token_id == pad_token_id: + break + cleaned_ids.append(token_id) + + # 解码清理后的 token IDs + completion_text = self.template.safe_decode(cleaned_ids) + + completion_text = completion_text.strip() + + completion_texts.append(completion_text) + else: + completion_texts = None + return generated_tokens, new_attention_mask, new_labels, completion_texts @profiling_decorator def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): @@ -447,6 +533,64 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: loss = loss + self.args.sft_alpha * outputs_student.loss # Separate teacher model provided + elif self.use_mopd: + num_teacher_models = len(self.teacher_model_group) + if num_teacher_models == 0: + raise ValueError('teacher_model_group cannot be empty') + loss = torch.tensor(0.0, device=model.device) + prompt_texts = inputs['prompt_text'] + completion_texts = inputs['completion_texts'] + ( + student_input_ids, + student_labels, + student_attention_mask, + student_prompt_length, + ) = self.build_inputs_from_texts(self.student_tokenizer, prompt_texts, completion_texts) + # Student model forward pass (WITH gradients for student parameters) + outputs_student = model( + input_ids=student_input_ids, + attention_mask=student_attention_mask, + ) + for teacher_model in self.teacher_model_group: + teacher_tokenizer = self.teacher_tokenizer_group[id(teacher_model)] + # Add teacher model memory management like in liger branch + load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() + + # Get adapter for current teacher tokenizer + adapter_key = id(teacher_tokenizer) + gold_adapter = self.gold_adapter_group[adapter_key] + with load_context: + with torch.no_grad(), disable_gradient_checkpointing(teacher_model, + self.args.gradient_checkpointing_kwargs): + ( + teacher_input_ids, + teacher_labels, + teacher_attention_mask, + teacher_prompt_length, + ) = self.build_inputs_from_texts(teacher_tokenizer, prompt_texts, completion_texts) + + # Teacher model forward pass (NO gradients) + outputs_teacher = teacher_model( + input_ids=teacher_input_ids, + attention_mask=teacher_attention_mask, + ) + # Ensure teacher_logits has gradient info but teacher model params don't participate + teacher_logits = outputs_teacher.logits.detach().requires_grad_(True) + + # Release intermediate tensors to free memory + del teacher_attention_mask + + # trl/experimental/gold/gold_trainer.py + loss_total = gold_adapter( + student_logits=outputs_student.logits, + teacher_logits=teacher_logits, + student_labels=student_labels, + teacher_labels=teacher_labels, + student_input_ids=student_input_ids, + teacher_input_ids=teacher_input_ids, + ) + loss += loss_total / len(self.teacher_model_group) + # Separate teacher model provided else: assert self.teacher_model is not None if self.args.sft_alpha > 0: @@ -470,11 +614,112 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: loss = loss + self.args.sft_alpha * outputs_student.loss - # Return loss - if return_outputs: - return (loss, outputs_student) - else: - return loss + # Return loss + if return_outputs: + return (loss, outputs_student) + else: + return loss + + def build_inputs_from_texts( + self, + tokenizer: PreTrainedTokenizerBase, + prompt_texts: list[str], + completion_texts: list[str], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Tokenize teacher prompts/completions and produce tensors ready for GOLD loss.""" + + pad_token_id = tokenizer.pad_token_id + eos_token_id = tokenizer.eos_token_id + + prompt_token_ids = tokenizer(prompt_texts, add_special_tokens=True)['input_ids'] + completion_token_ids = tokenizer(completion_texts, add_special_tokens=False)['input_ids'] + + sequences: list[torch.Tensor] = [] + attention_masks: list[torch.Tensor] = [] + labels_list: list[torch.Tensor] = [] + prompt_lengths: list[int] = [] + # Get device using reliable detection method + device = None + try: + # First try to get device from model parameters + if hasattr(self, 'model') and self.model is not None: + device = next(self.model.parameters()).device + elif hasattr(self, 'teacher_model') and self.teacher_model is not None: + device = next(self.teacher_model.parameters()).device + except (AttributeError, StopIteration): + pass + + # Fallback to default device detection + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + for prompt_ids, completion_ids in zip(prompt_token_ids, completion_token_ids, strict=True): + # Remove trailing EOS from prompt so completions can extend cleanly + if eos_token_id is not None and prompt_ids and prompt_ids[-1] == eos_token_id: + prompt_ids = prompt_ids[:-1] + + prompt_lengths.append(len(prompt_ids)) + sequence = list(prompt_ids) + sequence.extend(completion_ids) + if eos_token_id is not None: + sequence.append(eos_token_id) + + seq_tensor = torch.tensor(sequence, dtype=torch.long, device=device) + sequences.append(seq_tensor) + attention_masks.append(torch.ones_like(seq_tensor)) + labels = seq_tensor.clone() + labels[:len(prompt_ids)] = -100 + if pad_token_id is not None: + labels[labels == pad_token_id] = -100 + labels_list.append(labels) + + teacher_input_ids = pad( + sequences, + padding_side='right', + padding_value=pad_token_id if pad_token_id is not None else 0, + ) + teacher_attention_mask = pad(attention_masks, padding_side='right', padding_value=0).bool() + teacher_labels = pad(labels_list, padding_side='right', padding_value=-100) + + if eos_token_id is not None: + for row in range(teacher_attention_mask.size(0)): + valid = ( + teacher_input_ids[row] != pad_token_id + if pad_token_id is not None else teacher_attention_mask[row].bool()) + if valid.any(): + last_idx = valid.nonzero(as_tuple=True)[0][-1] + teacher_attention_mask[row, last_idx + 1:] = False + + teacher_prompt_length = max(prompt_lengths) if prompt_lengths else 0 + + return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length + + def get_model_path(self, model): + model_path = getattr(model, 'name_or_path', None) + if model_path is None: + # Try to get path from config + if hasattr(model, 'config') and hasattr(model.config, '_name_or_path'): + model_path = model.config._name_or_path + if model_path is None: + # If still None, try to get from model's base model + unwrapped_student = self.accelerator.unwrap_model(model) + if hasattr(unwrapped_student, 'base_model_prefix'): + base_model = getattr(unwrapped_student, unwrapped_student.base_model_prefix, unwrapped_student) + if hasattr(base_model, 'config') and hasattr(base_model.config, '_name_or_path'): + model_path = base_model.config._name_or_path + # Additional fallback: try to get from model's config name_or_path attribute + if model_path is None: + if hasattr(model, 'config') and hasattr(model.config, 'name_or_path'): + model_path = model.config.name_or_path + # Additional fallback: try to get from unwrapped model's name_or_path + if model_path is None: + unwrapped_student = self.accelerator.unwrap_model(model) + model_path = getattr(unwrapped_student, 'name_or_path', None) + # Additional fallback: try to get from unwrapped model's config name_or_path + if model_path is None: + unwrapped_student = self.accelerator.unwrap_model(model) + if hasattr(unwrapped_student, 'config') and hasattr(unwrapped_student.config, 'name_or_path'): + model_path = unwrapped_student.config.name_or_path + return model_path def _prepare_batch_inputs(self, inputs: list, encode_prompt_only: bool = False) -> Dict[str, torch.Tensor]: """Prepare batch inputs for training. @@ -913,16 +1158,16 @@ def _align_vocab_size(student_logits, teacher_logits): return student_logits, teacher_logits def generalized_jsd_loss( - self, - student_logits, - teacher_logits=None, - labels=None, - beta=0.5, - temperature=1.0, - chunk_size=512, - topk=None, - teacher_topk_logprobs=None, - teacher_topk_indices=None, + self, + student_logits, + teacher_logits=None, + labels=None, + beta=0.5, + temperature=1.0, + chunk_size=512, + topk=None, + teacher_topk_logprobs=None, + teacher_topk_indices=None, ): # Align vocab sizes when student and teacher have different vocabulary dimensions if teacher_logits is not None: diff --git a/swift/rlhf_trainers/gold_loss_adapter.py b/swift/rlhf_trainers/gold_loss_adapter.py new file mode 100644 index 0000000000..c7948ca984 --- /dev/null +++ b/swift/rlhf_trainers/gold_loss_adapter.py @@ -0,0 +1,513 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PreTrainedTokenizerBase + + +class GOLDLossAdapter(nn.Module): + """ + - GOLD (General Online Logit Distillation) 损失函数适配器 + 支持: + 1. ULD损失 (Universal Logit Distillation) + 2. 扩展ULD (跨tokenizer对齐) + 3. 混合损失 (Hybrid ULD + JSD) + + 使用示例: + adapter = GOLDLossAdapter( + config={ + 'use_uld_loss': True, + 'use_extended_uld': True, + 'uld_use_hybrid_loss': False, + 'uld_crossentropy_weight': 0.0, + 'uld_distillation_weight': 1.0, + 'uld_student_temperature': 1.0, + 'uld_teacher_temperature': 1.0, + }, + student_tokenizer=student_tok, + teacher_tokenizer=teacher_tok, + ) + + loss = adapter( + student_logits=student_outputs.logits, + teacher_logits=teacher_outputs.logits, + student_labels=student_labels, + teacher_labels=teacher_labels, + student_input_ids=student_input_ids, + teacher_input_ids=teacher_input_ids, + ) + """ + + def __init__( + self, + config: dict, + student_tokenizer: Optional[PreTrainedTokenizerBase] = None, + teacher_tokenizer: Optional[PreTrainedTokenizerBase] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + self.device = device + + # 基础配置 + self.use_uld_loss = config.get('use_uld_loss', True) # 是否开启通用蒸馏 + self.crossentropy_weight = config.get('uld_crossentropy_weight', 0.0) + self.distillation_weight = config.get('uld_distillation_weight', 1.0) + self.student_temperature = config.get('uld_student_temperature', 0.9) + self.teacher_temperature = config.get('uld_teacher_temperature', 0.9) + self.skip_student_eos = config.get('uld_skip_student_eos', True) + self.skip_teacher_eos = config.get('uld_skip_teacher_eos', True) + self.use_extended_uld = config.get('use_extended_uld', True) + self.ignore_index = -100 + + # Tokenizers + self.student_tokenizer = student_tokenizer + self.teacher_tokenizer = teacher_tokenizer + + # Hybrid ULD配置 + self.use_hybrid_loss = config.get('uld_use_hybrid_loss', True) # 是否对完全匹配的词汇进行匹配,开启提高稳定性 + self.hybrid_matched_weight = config.get('uld_hybrid_matched_weight', None) + self.hybrid_unmatched_weight = config.get('uld_hybrid_unmatched_weight', None) + self.beta = config.get('beta', 1.0) + + # 初始化词汇映射(用于hybrid loss) + self._vocab_mapping = None + self._teacher_matched_ids = None + self._student_matched_ids = None + self.mapping_tensor = None + + if self.use_hybrid_loss and student_tokenizer and teacher_tokenizer: + self._initialize_vocabulary_mapping() + + # 用于logging + self.last_matched_loss = None + self.last_unmatched_loss = None + + def _initialize_vocabulary_mapping(self): + """初始化学生-教师tokenizer的词汇映射""" + student_vocab = self.student_tokenizer.get_vocab() + teacher_vocab = self.teacher_tokenizer.get_vocab() + + student_token_to_id = dict(student_vocab.items()) + + vocab_mapping = {} + teacher_matched_ids = set() + student_matched_ids = set() + + for token_str, teacher_id in teacher_vocab.items(): + if token_str in student_token_to_id: + student_id = student_token_to_id[token_str] + vocab_mapping[teacher_id] = student_id + teacher_matched_ids.add(teacher_id) + student_matched_ids.add(student_id) + + self._vocab_mapping = vocab_mapping + self._teacher_matched_ids = teacher_matched_ids + self._student_matched_ids = student_matched_ids + + if self._vocab_mapping: + max_matched_teacher_id = max(self._vocab_mapping.keys()) + self.mapping_tensor = torch.full((max_matched_teacher_id + 1,), -1, dtype=torch.long) + for k, v in self._vocab_mapping.items(): + self.mapping_tensor[k] = v + if self.device: + self.mapping_tensor = self.mapping_tensor.to(self.device) + + def forward( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + student_labels: torch.Tensor, + teacher_labels: torch.Tensor, + student_input_ids: torch.Tensor, + teacher_input_ids: torch.Tensor, + ) -> torch.Tensor: + """ + 计算GOLD/ULD损失 + + Args: + student_logits: [batch_size, seq_len, student_vocab_size] + teacher_logits: [batch_size, seq_len, teacher_vocab_size] + student_labels: [batch_size, seq_len], -100表示忽略 + teacher_labels: [batch_size, seq_len], -100表示忽略 + student_input_ids: [batch_size, seq_len] + teacher_input_ids: [batch_size, seq_len] + + Returns: + loss: scalar tensor + """ + + if not self.use_uld_loss: + return torch.tensor(0.0, device=student_logits.device, requires_grad=True) + + # 1. Cross-entropy loss (可选,通过crossentropy_weight设置权重) + crossentropy_loss = self._compute_cross_entropy(student_logits, student_labels) + + # 2. Distillation loss (ULD) + distillation_loss = self._compute_distillation_loss(student_logits, teacher_logits, student_labels, + teacher_labels, student_input_ids, teacher_input_ids) + return crossentropy_loss + distillation_loss + + def _compute_cross_entropy(self, student_logits: torch.Tensor, student_labels: torch.Tensor) -> torch.Tensor: + """计算cross-entropy loss""" + if self.crossentropy_weight <= 0: + return torch.tensor(0.0, device=student_logits.device, requires_grad=True) + + shift_logits = student_logits[..., :-1, :].contiguous() + shift_labels = student_labels[..., 1:].contiguous() + + loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + ce_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + return self.crossentropy_weight * ce_loss + + def _compute_distillation_loss( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + student_labels: torch.Tensor, + teacher_labels: torch.Tensor, + student_input_ids: torch.Tensor, + teacher_input_ids: torch.Tensor, + ) -> torch.Tensor: + """计算ULD蒸馏损失""" + # 获取答案区域 + student_answer_idx, student_answer_size = self._get_answer_regions(student_labels) + teacher_answer_idx, teacher_answer_size = self._get_answer_regions(teacher_labels) + + if self.skip_student_eos: + student_answer_size = [s - 1 for s in student_answer_size] + if self.skip_teacher_eos: + teacher_answer_size = [t - 1 for t in teacher_answer_size] + + # 边界检查 + if not student_answer_size or not teacher_answer_size: + return torch.zeros(1, device=student_logits.device, requires_grad=True) * 1e-8 + + batch_size = student_logits.size(0) + distillation_losses = [] + + for i in range(batch_size): + s_start = student_answer_idx[i] + s_size = student_answer_size[i] + t_start = teacher_answer_idx[i] + t_size = teacher_answer_size[i] + + if s_size <= 0 or t_size <= 0: + loss_i = student_logits[i].sum() * 0.0 + # Ensure the loss tensor requires gradients + loss_i = loss_i.detach().requires_grad_(True) + distillation_losses.append(loss_i) + continue + + # 提取答案logits + student_ans_logits = student_logits[i, s_start:s_start + s_size] + teacher_ans_logits = teacher_logits[i, t_start:t_start + t_size] + + # 转换为概率 + student_probs = F.softmax(student_ans_logits / self.student_temperature, dim=-1) + teacher_probs = F.softmax(teacher_ans_logits / self.teacher_temperature, dim=-1) + + student_token_ids = student_input_ids[i, s_start:s_start + s_size].tolist() + teacher_token_ids = teacher_input_ids[i, t_start:t_start + t_size].tolist() + + # Token对齐 + if self.use_extended_uld: + student_groups, teacher_groups = self._build_alignment_groups_from_ids( + student_token_ids, teacher_token_ids) + + student_aligned = self._merge_probabilities_with_groups(student_probs, student_groups, + student_token_ids) + teacher_aligned = self._merge_probabilities_with_groups(teacher_probs, teacher_groups, + teacher_token_ids) + + else: + min_len = min(len(student_token_ids), len(teacher_token_ids)) + student_aligned = student_probs[:min_len] + teacher_aligned = teacher_probs[:min_len] + + # 计算损失 + if self.use_hybrid_loss and self._vocab_mapping: + aligned_loss = self._compute_hybrid_uld_loss(student_aligned, teacher_aligned) + else: + aligned_loss = self._compute_basic_uld_loss(student_aligned, teacher_aligned) + + distillation_losses.append(aligned_loss) + distillation_loss = torch.stack(distillation_losses).mean() + return self.distillation_weight * distillation_loss + + def _get_answer_regions(self, labels: torch.Tensor) -> Tuple[List[int], List[int]]: + """获取答案区域的起始位置和大小""" + indices = [] + sizes = [] + + for label in labels: + mask = label.ne(self.ignore_index) + if not mask.any(): + indices.append(0) + sizes.append(0) + continue + + valid_indices = mask.nonzero(as_tuple=True)[0] + indices.append(int(valid_indices[0].item())) + sizes.append(int(mask.sum().item())) + + return indices, sizes + + def _build_alignment_groups_from_ids(self, student_token_ids: List[int], + teacher_token_ids: List[int]) -> Tuple[List[List[int]], List[List[int]]]: + """ + 基于文本内容构建对齐组 + 使用贪心子串匹配算法 + """ + + def decode_tokens(tokenizer, token_ids): + pieces = [] + prev = '' + for k in range(len(token_ids)): + cur = tokenizer.decode(token_ids[:k + 1], skip_special_tokens=False) + pieces.append(cur[len(prev):]) + prev = cur + return pieces + + student_pieces = decode_tokens(self.student_tokenizer, student_token_ids) + teacher_pieces = decode_tokens(self.teacher_tokenizer, teacher_token_ids) + + # 贪心匹配算法 + student_groups = [] + teacher_groups = [] + s_idx = 0 + t_idx = 0 + + while s_idx < len(student_pieces) and t_idx < len(teacher_pieces): + student_text = '' + teacher_text = '' + student_group = [] + teacher_group = [] + + # 尝试找到最短的连续匹配序列 + while s_idx < len(student_pieces) and t_idx < len(teacher_pieces): + if not student_group: + student_group.append(s_idx) + student_text += student_pieces[s_idx] + s_idx += 1 + + if not teacher_group: + teacher_group.append(t_idx) + teacher_text += teacher_pieces[t_idx] + t_idx += 1 + + # 检查是否匹配 + if student_text == teacher_text: + student_groups.append(student_group) + teacher_groups.append(teacher_group) + break + elif len(student_text) < len(teacher_text): + if s_idx < len(student_pieces): + student_group.append(s_idx) + student_text += student_pieces[s_idx] + s_idx += 1 + else: + break + else: + if t_idx < len(teacher_pieces): + teacher_group.append(t_idx) + teacher_text += teacher_pieces[t_idx] + t_idx += 1 + else: + break + else: + # 未完全匹配,添加剩余部分 + if student_group and teacher_group: + student_groups.append(student_group) + teacher_groups.append(teacher_group) + + return student_groups, teacher_groups + + def _merge_probabilities_with_groups( + self, + probs: torch.Tensor, + alignment_groups: List[List[int]], + token_ids: List[int], + ) -> torch.Tensor: + """ + 根据对齐组合并概率分布 + 使用链式法则: P_merged = P(y|x_0) * P(x_1|x_0) * P(x_2|x_0,x_1) * ... + """ + aligned_probs = [] + + for group in alignment_groups: + if len(group) > 1: + # 第一个token的边际概率 + marginal_probs = probs[group[0]] # [vocab_size] + + # 后续token的条件概率(标量) + conditional_product = 1.0 + for k in range(1, len(group)): + cond_prob = probs[group[k], token_ids[group[k - 1]]] + conditional_product *= cond_prob + + merged_probs = marginal_probs * conditional_product + aligned_probs.append(merged_probs) + elif len(group) == 1: + aligned_probs.append(probs[group[0]]) + + if aligned_probs: + return torch.stack(aligned_probs) + else: + # 返回一个空的但需要梯度的张量 + empty_tensor = probs[:0].detach().requires_grad_(True) + return empty_tensor + + def _compute_basic_uld_loss( + self, + student_aligned: torch.Tensor, + teacher_aligned: torch.Tensor, + ) -> torch.Tensor: + """基础ULD损失:排序后的L1距离""" + student_sorted = student_aligned.sort(dim=-1, descending=True).values + teacher_sorted = teacher_aligned.sort(dim=-1, descending=True).values + + # Padding到相同vocab size + s_vocab = student_sorted.size(-1) + t_vocab = teacher_sorted.size(-1) + max_vocab = max(s_vocab, t_vocab) + + if s_vocab < max_vocab: + student_sorted = F.pad(student_sorted, (0, max_vocab - s_vocab)) + if t_vocab < max_vocab: + teacher_sorted = F.pad(teacher_sorted, (0, max_vocab - t_vocab)) + + loss = F.l1_loss(student_sorted, teacher_sorted, reduction='sum') + loss /= student_aligned.size(0) + + return loss + + def _compute_hybrid_uld_loss( + self, + student_aligned: torch.Tensor, + teacher_aligned: torch.Tensor, + ) -> torch.Tensor: + """混合ULD损失:matched用JSD,unmatched用排序L1""" + device = student_aligned.device + s_vocab = student_aligned.size(-1) + t_vocab = teacher_aligned.size(-1) + + # 创建matched/unmatched masks + if self._teacher_matched_ids: + teacher_matched_idx = torch.tensor(sorted(self._teacher_matched_ids), dtype=torch.long, device=device) + student_matched_idx = self.mapping_tensor[teacher_matched_idx] + else: + teacher_matched_idx = torch.tensor([], dtype=torch.long, device=device) + student_matched_idx = torch.tensor([], dtype=torch.long, device=device) + + teacher_matched_mask = torch.zeros(t_vocab, dtype=torch.bool, device=device) + student_matched_mask = torch.zeros(s_vocab, dtype=torch.bool, device=device) + + if len(teacher_matched_idx) > 0: + teacher_matched_mask[teacher_matched_idx] = True + student_matched_mask[student_matched_idx] = True + + # 1. Matched tokens的JSD损失 + matched_loss = torch.tensor(0.0, device=device, requires_grad=True) + matched_count = 0 + + if len(teacher_matched_idx) > 0: + teacher_matched_probs = teacher_aligned[:, teacher_matched_idx] + student_matched_probs = student_aligned[:, student_matched_idx] + matched_count = teacher_matched_probs.size(-1) + + matched_loss = self._compute_jsd_for_matched(student_matched_probs, teacher_matched_probs) + # 2. Unmatched tokens的排序L1损失 + teacher_unmatched = teacher_aligned[:, ~teacher_matched_mask] + student_unmatched = student_aligned[:, ~student_matched_mask] + + unmatched_loss = torch.tensor(0.0, device=device, requires_grad=True) + if teacher_unmatched.size(-1) > 0 and student_unmatched.size(-1) > 0: + teacher_sorted = teacher_unmatched.sort(dim=-1, descending=True).values + student_sorted = student_unmatched.sort(dim=-1, descending=True).values + + t_size = teacher_sorted.size(-1) + s_size = student_sorted.size(-1) + max_size = max(t_size, s_size) + + if t_size < max_size: + teacher_sorted = F.pad(teacher_sorted, (0, max_size - t_size)) + if s_size < max_size: + student_sorted = F.pad(student_sorted, (0, max_size - s_size)) + + unmatched_loss = F.l1_loss(student_sorted, teacher_sorted, reduction='sum') + unmatched_loss /= student_aligned.size(0) + + # 3. 加权组合 + if self.hybrid_matched_weight is None: + w_matched = matched_count / max(1, t_vocab) + w_unmatched = 1.0 - w_matched + else: + w_matched = self.hybrid_matched_weight + w_unmatched = self.hybrid_unmatched_weight + + total_loss = w_matched * matched_loss + w_unmatched * unmatched_loss + + # 保存用于logging + self.last_matched_loss = matched_loss + self.last_unmatched_loss = unmatched_loss + + return total_loss + + def _compute_jsd_for_matched(self, + student_probs: torch.Tensor, + teacher_probs: torch.Tensor, + epsilon: float = 1e-8) -> torch.Tensor: + """计算matched tokens的JSD损失,添加数值稳定性处理""" + batch_seq_len, num_matched = student_probs.shape + + # 检查输入概率分布是否有效 + if torch.isnan(student_probs).any() or torch.isnan(teacher_probs).any(): + return torch.tensor(0.0, device=student_probs.device, requires_grad=True) + + # 添加epsilon防止数值下溢和log(0) + student_probs = student_probs.clamp(min=epsilon) + teacher_probs = teacher_probs.clamp(min=epsilon) + + # 重新归一化概率分布 + student_probs = student_probs / student_probs.sum(dim=-1, keepdim=True) + teacher_probs = teacher_probs / teacher_probs.sum(dim=-1, keepdim=True) + + student_flat = student_probs.view(-1, num_matched) + teacher_flat = teacher_probs.view(-1, num_matched) + + # JSD = 0.5 * KL(P||M) + 0.5 * KL(Q||M), where M = 0.5*(P+Q) + m = 0.5 * (student_flat + teacher_flat) + + # 添加epsilon到中间分布 + m = m.clamp(min=epsilon) + m = m / m.sum(dim=-1, keepdim=True) + + # 直接对概率分布取对数,添加epsilon防止数值问题 + log_m = torch.log(m + epsilon) + log_student = torch.log(student_flat + epsilon) + log_teacher = torch.log(teacher_flat + epsilon) + + # 使用log_target=True,传入log概率 + kl_p_m = F.kl_div(log_m, log_student, reduction='batchmean', log_target=True) + kl_q_m = F.kl_div(log_m, log_teacher, reduction='batchmean', log_target=True) + jsd = 0.5 * (kl_p_m + kl_q_m) + + # 检查结果是否有效 + if torch.isnan(jsd) or torch.isinf(jsd): + return torch.tensor(0.0, device=student_probs.device, requires_grad=True) + + return jsd