Skip to content

Commit 6d0d2ae

Browse files
committed
Fix Megatron rope theta compatibility
1 parent 9d75910 commit 6d0d2ae

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

src/art/megatron/provider.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,21 @@ def has_glob(self, pattern: str) -> bool:
5757
return self._source.has_glob(pattern)
5858

5959

60+
def _ensure_rope_theta(bridge: AutoBridge) -> None:
61+
config = bridge.hf_pretrained.config
62+
if hasattr(config, "rope_theta"):
63+
return
64+
65+
for rope_config_name in ("rope_scaling", "rope_parameters"):
66+
rope_config = getattr(config, rope_config_name, None)
67+
if not isinstance(rope_config, dict):
68+
continue
69+
rope_theta = rope_config.get("rope_theta")
70+
if isinstance(rope_theta, int | float):
71+
setattr(config, "rope_theta", float(rope_theta))
72+
return
73+
74+
6075
def get_provider(
6176
model: str,
6277
*,
@@ -70,6 +85,7 @@ def get_provider(
7085
assert isinstance(bridge._model_bridge, Qwen3MoEBridge), (
7186
"Only Qwen3 MoE models are supported"
7287
)
88+
_ensure_rope_theta(bridge)
7389
if torch_dtype != torch.bfloat16:
7490
model_name_or_path = bridge.hf_pretrained.model_name_or_path
7591
assert model_name_or_path is not None

0 commit comments

Comments
 (0)