fix(models): Fix dtype mismatch in SwitchTransformers and TimmWrapperModel#45074
Open
harshaljanjani wants to merge 2 commits intohuggingface:mainfrom
Open
Conversation
Contributor
|
[For maintainers] Suggested jobs to run (before merge) run-slow: switch_transformers, timm_wrapper |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
The following dtype mismatch use cases were identified and fixed in this PR:
→ Switch Transformers: 7938e91fa refactored all MoE models for vLLM compatibility; in that refactor, the
_cast_classifier()method was removed fromSwitchTransformersTop1Routerbut no dtype cast was added. Casting hidden_states toclassifier.weight.dtypebefore the linear call fixes that!→ TimmWrapper: 6217adc6c8 changed the default dtype behavior to "auto"; in that commit,
pixel_values.to(self.device, self.dtype)was regressed to pixel_values.to(self.device) dropping the dtype cast. I'm not too sure why it was dropped; but restoring it seems logical to fix the use case.→ For more details on reproducing the bug and the output screenshots, please visit the linked issue!
cc: @Rocketknight1
Fixes #45072
CI run test coverage of this behavior (as suggested by @ydshieh) :):
SwitchTransformers:
→
test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_generate_with_past_key_values→
test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_model_fp16_forward→
test_modeling_switch_transformers.py::SwitchTransformerModelIntegrationTests::test_small_logitsTimmWrapper:
→
TimmWrapperModelTestdoes not have explicit bfloat16 forward pass tests; added one in this PR for complete coverage.Repro output after the fixes (feel free to cross-check):
Code Agent Policy
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.