File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -57,6 +57,21 @@ def has_glob(self, pattern: str) -> bool:
5757 return self ._source .has_glob (pattern )
5858
5959
60+ def _ensure_rope_theta (bridge : AutoBridge ) -> None :
61+ config = bridge .hf_pretrained .config
62+ if hasattr (config , "rope_theta" ):
63+ return
64+
65+ for rope_config_name in ("rope_scaling" , "rope_parameters" ):
66+ rope_config = getattr (config , rope_config_name , None )
67+ if not isinstance (rope_config , dict ):
68+ continue
69+ rope_theta = rope_config .get ("rope_theta" )
70+ if isinstance (rope_theta , int | float ):
71+ setattr (config , "rope_theta" , float (rope_theta ))
72+ return
73+
74+
6075def get_provider (
6176 model : str ,
6277 * ,
@@ -70,6 +85,7 @@ def get_provider(
7085 assert isinstance (bridge ._model_bridge , Qwen3MoEBridge ), (
7186 "Only Qwen3 MoE models are supported"
7287 )
88+ _ensure_rope_theta (bridge )
7389 if torch_dtype != torch .bfloat16 :
7490 model_name_or_path = bridge .hf_pretrained .model_name_or_path
7591 assert model_name_or_path is not None
You can’t perform that action at this time.
0 commit comments