Skip to content

Commit 8f84b2d

Browse files
authored
Merge pull request #437 from Modalities/fix_compile_weight_init_bug
Fixes weight-initialization regexes that failed to match because torch.compile prefixes parameter FQNs with `_orig_mod.`
2 parents 627f84e + 8a1ae86 commit 8f84b2d

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

src/modalities/models/gpt2/llama3_like_initialization.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
9999
}
100100

101101
def initialize_in_place(self, model: nn.Module):
102-
self._init_by_fqn_regex(model, self.regex_to_init, depth_init=self.depth_init)
102+
self._init_by_fqn_regex(model, self.regex_to_init)
103103

104104
@staticmethod
105-
def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]], depth_init: bool):
105+
def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]]):
106106
hits = {k: 0 for k in regex_to_init.keys()}
107107

108108
for parameter_name, p in model.named_parameters():
@@ -112,6 +112,9 @@ def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable
112112
)
113113
match_count = 0
114114
for weight_regex in regex_to_init.keys():
115+
parameter_name = parameter_name.replace(
116+
"_orig_mod.", ""
117+
) # remove FQN modification from torch.compile if present
115118
if re.fullmatch(weight_regex, parameter_name):
116119
init_fn, arg_dict = regex_to_init[weight_regex]
117120
if arg_dict["std"] is not None and callable(arg_dict["std"]):

src/modalities/nn/model_initialization/initialization_routines.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def initialize_in_place(self, model: nn.Module):
4848
weight_regexes = self.parameter_name_regexes.weights
4949
bias_regexes = self.parameter_name_regexes.biases
5050
for parameter_name, p in model.named_parameters():
51+
parameter_name = parameter_name.replace(
52+
"_orig_mod.", ""
53+
) # remove FQN modification from torch.compile if present
5154
for weight_regex in weight_regexes:
5255
if re.fullmatch(weight_regex, parameter_name):
5356
nn.init.normal_(p, mean=self.mean, std=self.std)

0 commit comments

Comments (0)