
fix(models): Fix dtype mismatch in SwitchTransformers and TimmWrapperModel#45074

Open
harshaljanjani wants to merge 2 commits into huggingface:main from harshaljanjani:fix/switch-transformers-timm-wrapper-bf16-dtype
Conversation

@harshaljanjani
Contributor

@harshaljanjani harshaljanjani commented Mar 27, 2026

What does this PR do?

The following dtype mismatch use cases were identified and fixed in this PR:

Switch Transformers: 7938e91fa refactored all MoE models for vLLM compatibility; in that refactor, the _cast_classifier() method was removed from SwitchTransformersTop1Router without adding a replacement dtype cast. Casting hidden_states to classifier.weight.dtype before the linear call fixes that!
TimmWrapper: 6217adc6c8 changed the default dtype behavior to "auto"; in that commit, pixel_values.to(self.device, self.dtype) was reduced to pixel_values.to(self.device), dropping the dtype cast. I'm not sure why it was dropped, but restoring it fixes the mismatch.
→ For more details on reproducing the bug and the output screenshots, please visit the linked issue!
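The two casts can be illustrated with a minimal, self-contained torch sketch. The modules and tensors below are stand-ins for illustration only, not the actual transformers internals:

```python
import torch
import torch.nn as nn

# Stand-in for SwitchTransformersTop1Router.classifier: router weights are
# kept in float32 for stable routing while activations run in bfloat16.
classifier = nn.Linear(8, 4, bias=False)  # float32 weights
hidden_states = torch.randn(2, 8, dtype=torch.bfloat16)

# Without a cast, the linear call raises a dtype-mismatch RuntimeError.
# The fix: cast the activations to the classifier's weight dtype first.
router_logits = classifier(hidden_states.to(classifier.weight.dtype))
print(router_logits.dtype)  # torch.float32

# Stand-in for the TimmWrapper fix: move inputs to the model's device AND
# dtype, mirroring the restored `pixel_values.to(self.device, self.dtype)`.
model = nn.Conv2d(3, 16, 3).to(torch.bfloat16)
pixel_values = torch.randn(1, 3, 32, 32)  # float32 input
out = model(pixel_values.to(model.weight.device, model.weight.dtype))
print(out.dtype)  # torch.bfloat16
```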

cc: @Rocketknight1

Fixes #45072

CI tests covering this behavior (as suggested by @ydshieh):

SwitchTransformers:
test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_generate_with_past_key_values
test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_model_fp16_forward
test_modeling_switch_transformers.py::SwitchTransformerModelIntegrationTests::test_small_logits
TimmWrapper:
TimmWrapperModelTest did not have an explicit bfloat16 forward-pass test; this PR adds one for complete coverage.
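The shape of the added test follows the usual dtype-specific forward-pass pattern. A hedged sketch with a stand-in module (the real test instantiates TimmWrapperModel from a small timm checkpoint rather than this toy network):

```python
import torch
import torch.nn as nn

def test_bfloat16_forward():
    # Stand-in for a TimmWrapperModel loaded in bfloat16; the actual test
    # would pass torch_dtype=torch.bfloat16 to from_pretrained.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).to(torch.bfloat16)
    model.eval()

    # Image processors produce float32 pixel values by default.
    pixel_values = torch.randn(1, 3, 16, 16)

    with torch.no_grad():
        # The fix casts inputs to the model dtype before the forward pass,
        # so a float32 input no longer crashes a bf16 model.
        out = model(pixel_values.to(torch.bfloat16))

    assert out.dtype == torch.bfloat16

test_bfloat16_forward()
```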

Repro output after the fixes (feel free to cross-check):


Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you fix any necessary existing tests?

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: switch_transformers, timm_wrapper

@harshaljanjani harshaljanjani marked this pull request as ready for review March 27, 2026 20:19


Development

Successfully merging this pull request may close these issues.

[BUG][CI] SwitchTransformers and TimmWrapperModel dtype mismatches in bfloat16 inference

1 participant