[megatron] support gemma4 megatron#9296
Conversation
There was a problem hiding this comment.
Code Review
This pull request updates the documentation to reflect support for Gemma 4 models and refactors embedding handling in Megatron utilities to support multiple modules during device conversion. In swift/model/models/gemma.py, a suggestion was made to use inputs_embeds.device instead of multimodal_mask.device when moving the pad_embedding tensor to ensure better robustness and consistency across operands in the subsequent torch.where call.
|
|
||
| if self.config.get_text_config().hidden_size_per_layer_input: | ||
| pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :] | ||
| pad_embedding = pad_embedding.to(multimodal_mask.device) |
There was a problem hiding this comment.
Using inputs_embeds.device as the target for the .to() call is generally more robust than multimodal_mask.device. Since inputs_embeds is the primary tensor representing the hidden states in this operation, it serves as the most reliable reference for the execution device, ensuring consistency across all operands in the subsequent torch.where call.
| pad_embedding = pad_embedding.to(multimodal_mask.device) | |
| pad_embedding = pad_embedding.to(inputs_embeds.device) |
No description provided.