 from modalities.conversion.gpt2.configuration_gpt2 import GPT2Config
 from modalities.conversion.gpt2.modeling_gpt2 import GPT2DecoderLayer, GPT2ForCausalLM
+from modalities.models.components.layer_norms import LayerNormConfig
 from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Block, PositionTypes
 from modalities.models.model import SwiGLU
 from modalities.models.utils import ModelTypeEnum, get_model_from_config
@@ -27,26 +28,6 @@ def convert_model_checkpoint(modalities_config: dict) -> tuple[GPT2ForCausalLM,
     return hf_model, modalities_model


-def _check_conversion_criteria(model_config: dict) -> None:
-    """Checks that the modalities config fulfills criteria necessary for conversion
-
-    Args:
-        model_config (dict): model or model_raw part of the Modalities config dictionary.
-
-    Returns:
-        None
-    """
-    assert model_config["poe_type"] == PositionTypes.NOPE
-    assert model_config["bias"] is False
-    assert model_config["activation_type"] == "swiglu"
-    assert model_config["attention_implementation"] in ["pytorch_flash", "manual"]
-
-    for norm in ["attention_norm", "ffn_norm", "lm_head_norm"]:
-        assert model_config[norm]["variant_key"] == "layer_norm"
-        assert model_config[norm]["config"].get("elementwise_affine", True) is True  # True = default setting
-        assert model_config[norm]["config"].get("bias", True) is True  # True = default setting
-
-
 def convert_model_config(modalities_config: dict) -> GPT2Config:
     """Converts the modalities model configuration to a Huggingface transformers configuration.
     For this the model_raw or model section of the modalities config is used.
@@ -59,16 +40,8 @@ def convert_model_config(modalities_config: dict) -> GPT2Config:
         GPT2Config: Converted Huggingface model configuration.
     """
     config = modalities_config["model_raw" if "model_raw" in modalities_config else "model"]["config"]
-
     _check_conversion_criteria(config)

-    if config["attention_implementation"] == "pytorch_flash":
-        attention_impl = "sdpa"
-    elif config["attention_implementation"] == "manual":
-        attention_impl = "eager"
-    else:
-        raise ValueError(f"Unknown or unsupported attention implementation {config['attention_implementation']}.")
-
     return GPT2Config(
         vocab_size=config["vocab_size"],
         hidden_size=config["n_embd"],
@@ -80,15 +53,12 @@ def convert_model_config(modalities_config: dict) -> GPT2Config:
         attention_bias=config["bias"],
         mlp_bias=config["bias"],
         hidden_act="silu",
-        layer_norm_eps=config["ffn_norm"]["config"]["eps"],
-        layer_norm_elementwise_affine=config["ffn_norm"]["config"].get(
-            "elementwise_affine",
-            True,
-        ),
-        layer_norm_bias=config["ffn_norm"]["config"].get("bias", True),
+        layer_norm_eps=_get_layer_norm_value(config["ffn_norm"]["config"], "eps"),
+        layer_norm_elementwise_affine=_get_layer_norm_value(config["ffn_norm"]["config"], "elementwise_affine"),
+        layer_norm_bias=_get_layer_norm_value(config["ffn_norm"]["config"], "bias"),
         max_position_embeddings=config["sequence_length"],
         rope_theta=config["attention_config"]["qkv_transforms"][0]["config"]["base_freq"],
-        _attn_implementation=attention_impl,
+        _attn_implementation=_map_attention_type(config),
         output_attentions=False,
     )

@@ -114,21 +84,64 @@ def check_converted_model(hf_model: GPT2ForCausalLM, modalities_model: GPT2LLM,
     assert torch.equal(llama_logits, modalities_logits)


-def _copy_weights_model(hf_model_model: GPT2ForCausalLM, modalities_model: GPT2LLM):
+def _check_conversion_criteria(model_config: dict) -> None:
+    """Checks that the modalities config fulfills the criteria necessary for conversion.
+
+    Args:
+        model_config (dict): model or model_raw part of the Modalities config dictionary.
+
+    Returns:
+        None
+    """
+    assert model_config["poe_type"] == PositionTypes.NOPE
+    assert model_config["activation_type"] == "swiglu"
+    assert model_config["attention_implementation"] in ["pytorch_flash", "manual"]
+
+    norms = ["attention_norm", "ffn_norm", "lm_head_norm"]
+    for norm in norms:
+        assert model_config[norm]["variant_key"] == "layer_norm"
+
+    assert (
+        len(set(_get_layer_norm_value(model_config[norm]["config"], "bias") for norm in norms)) == 1
+    ), "All norms must have the same bias setting."
+    assert (
+        len(set(_get_layer_norm_value(model_config[norm]["config"], "elementwise_affine") for norm in norms)) == 1
+    ), "All norms must have the same elementwise_affine setting."
+    assert (
+        len(set(_get_layer_norm_value(model_config[norm]["config"], "eps") for norm in norms)) == 1
+    ), "All norms must have the same eps setting."
+
+
+def _get_layer_norm_value(config: dict, field: str) -> bool | float | int:
+    default = LayerNormConfig.model_fields[field].default
+    return config.get(field, default)
+
+
+def _map_attention_type(config: dict) -> str:
+    if config["attention_implementation"] == "pytorch_flash":
+        attention_impl = "sdpa"
+    elif config["attention_implementation"] == "manual":
+        attention_impl = "eager"
+    else:
+        raise ValueError(f"Unknown or unsupported attention implementation {config['attention_implementation']}.")
+    return attention_impl
+
+
+def _copy_weights_model(hf_model: GPT2ForCausalLM, modalities_model: GPT2LLM):
     """Copies the weights of the modalities model to the Huggingface transformers model.

     Args:
-        hf_model_model (GPT2ForCausalLM): The uninitialized Huggingface transformers model.
+        hf_model (GPT2ForCausalLM): The uninitialized Huggingface transformers model.
             The weights will be copied here.
         modalities_model (GPT2LLM): The modalities model from which the weights will be copied.
     """
-    hf_model_model.model.embed_tokens.weight.data.copy_(modalities_model.transformer.wte.weight.data)
-    for hf_layer, modalities_layer in zip(hf_model_model.model.layers, modalities_model.transformer.h):
+    hf_model.model.embed_tokens.weight.data.copy_(modalities_model.transformer.wte.weight.data)
+    for hf_layer, modalities_layer in zip(hf_model.model.layers, modalities_model.transformer.h):
         _copy_weights_attention(hf_layer, modalities_layer)
         _copy_weights_mlp(hf_layer, modalities_layer)
         _copy_weights_layer_norms(hf_layer, modalities_layer)
-    _copy_weights_base_modules(hf_model_model.lm_head, modalities_model.lm_head)
-    _copy_weights_base_modules(hf_model_model.model.norm, modalities_model.transformer.lm_head_norm)
+    _copy_weights_base_modules(hf_model.lm_head, modalities_model.lm_head)
+    _copy_weights_base_modules(hf_model.model.norm, modalities_model.transformer.lm_head_norm)


 def _copy_weights_attention(hf_layer: GPT2DecoderLayer, modalities_layer: GPT2Block):
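
The new `_get_layer_norm_value` helper replaces hardcoded `.get(key, True)` fallbacks by reading defaults straight off the pydantic field definitions, so the conversion stays in sync if the defaults in `LayerNormConfig` ever change. Below is a minimal, self-contained sketch of that idea; `LayerNormConfigSketch` and its field defaults are hypothetical stand-ins here (the real class lives in `modalities.models.components.layer_norms` and may differ):

```python
# Sketch of the default-resolution idea behind _get_layer_norm_value,
# assuming a pydantic v2 BaseModel whose defaults mirror torch.nn.LayerNorm.
from pydantic import BaseModel


class LayerNormConfigSketch(BaseModel):
    # Assumed defaults; stand-in for the real LayerNormConfig.
    eps: float = 1e-5
    elementwise_affine: bool = True
    bias: bool = True


def get_layer_norm_value(config: dict, field: str) -> bool | float | int:
    # Fall back to the pydantic field default when the YAML config omits the key.
    default = LayerNormConfigSketch.model_fields[field].default
    return config.get(field, default)


print(get_layer_norm_value({"eps": 1e-6}, "eps"))   # 1e-6 (explicit value wins)
print(get_layer_norm_value({"eps": 1e-6}, "bias"))  # True (falls back to default)
```

This is also what lets `_check_conversion_criteria` relax the old `bias is False` and `elementwise_affine is True` requirements: instead of pinning each norm to the defaults, it now only requires that all three norms agree, since `GPT2Config` exposes a single set of layer-norm settings.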
0 commit comments
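For orientation, a hypothetical end-to-end usage of the conversion entry points touched by this diff; the module path, config file, and YAML loading are assumptions, not part of the change:

```python
# Hypothetical driver for the conversion API; paths and module location assumed.
import yaml

from modalities.conversion.gpt2.convert_gpt2 import (  # assumed module path
    convert_model_checkpoint,
    convert_model_config,
)

with open("configs/gpt2_pretraining.yaml") as f:  # hypothetical config file
    modalities_config = yaml.safe_load(f)

# Config conversion runs _check_conversion_criteria first, so an unsupported
# setup (e.g. learned positional embeddings) fails fast with an AssertionError.
hf_config = convert_model_config(modalities_config)

# Checkpoint conversion returns both models so the caller can compare logits.
hf_model, modalities_model = convert_model_checkpoint(modalities_config)
hf_model.save_pretrained("converted_gpt2")  # standard transformers serialization
```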