|
1 | 1 | from pathlib import Path |
2 | 2 |
|
| 3 | +import pytest |
3 | 4 | import torch |
4 | | -from transformers import AutoModelForCausalLM |
| 5 | +from transformers import AutoModelForCausalLM, PreTrainedModel |
5 | 6 |
|
6 | 7 | from modalities.config.config import load_app_config_dict |
7 | 8 | from modalities.conversion.gpt2.conversion_model import check_converted_model |
8 | 9 | from modalities.conversion.gpt2.convert_gpt2 import convert_gpt2 |
| 10 | +from modalities.models.gpt2.gpt2_model import GPT2LLM |
9 | 11 | from modalities.models.utils import ModelTypeEnum, get_model_from_config |
10 | 12 | from tests.conversion.gpt2.helper import check_same_weight_model |
11 | 13 |
|
12 | 14 |
|
def test_converting_gpt2_does_not_change_weights(converted_model: PreTrainedModel, original_model: GPT2LLM):
    """The converted Hugging Face model must carry exactly the same weights as the source model."""
    check_same_weight_model(converted_model, original_model)
20 | 17 |
|
21 | 18 |
|
def test_converting_gpt2_does_not_change_outputs(
    converted_model: PreTrainedModel, original_model: GPT2LLM, vocab_size: int
):
    """A forward pass through the converted model must reproduce the original model's outputs."""
    check_converted_model(
        hf_model=converted_model,
        modalities_model=original_model,
        num_testruns=1,
        vocab_size=vocab_size,
    )
| 25 | + |
| 26 | + |
@pytest.fixture
def converted_model(run_convert_gpt2: None, output_dir: Path) -> PreTrainedModel:
    """Load the converted Hugging Face checkpoint written by the conversion step.

    The `run_convert_gpt2` parameter is a pure ordering dependency: requesting it
    guarantees the conversion has populated `output_dir` before we load from it.
    """
    model = AutoModelForCausalLM.from_pretrained(output_dir, local_files_only=True, trust_remote_code=True)
    # NOTE(review): the bfloat16 cast now applies to the weight-equality test as
    # well, which previously loaded the model in its default dtype — presumably
    # equivalent for this checkpoint; confirm against check_same_weight_model.
    return model.to(dtype=torch.bfloat16)
| 32 | + |
| 33 | + |
@pytest.fixture
def run_convert_gpt2(gpt2_config_path: str, output_dir: Path) -> None:
    """Run the GPT-2 conversion once, writing the HF checkpoint into `output_dir`."""
    convert_gpt2(gpt2_config_path, output_dir)
| 37 | + |
| 38 | + |
@pytest.fixture
def original_model(gpt2_config_path: str) -> GPT2LLM:
    """Instantiate the unconverted modalities GPT-2 model from the test config."""
    app_config = load_app_config_dict(gpt2_config_path)
    return get_model_from_config(app_config, model_type=ModelTypeEnum.CHECKPOINTED_MODEL)
| 43 | + |
| 44 | + |
@pytest.fixture
def vocab_size(gpt2_config_path: str) -> int:
    """Read the vocabulary size from the model section of the test config.

    Configs may name the model section either "model_raw" or "model";
    "model_raw" takes precedence when both are present.
    """
    app_config = load_app_config_dict(gpt2_config_path)
    model_section = "model_raw" if "model_raw" in app_config else "model"
    return app_config[model_section]["config"]["vocab_size"]
| 49 | + |
| 50 | + |
@pytest.fixture
def output_dir(tmp_path: Path) -> Path:
    """Per-test directory that receives the converted checkpoint artifacts."""
    return tmp_path.joinpath("output")
0 commit comments