Commit 000a9fa

test(huggingface): Added one additional test and refactored some others.

1 parent: 60bab8a

3 files changed: 45 additions & 18 deletions

tests/conversion/gpt2/conftest.py

Lines changed: 3 additions & 2 deletions
@@ -11,8 +11,9 @@
 from tests.conftest import _ROOT_DIR
 
 
-@pytest.fixture()
-def gpt2_config_path(tmp_path: Path, initialized_model: GPT2LLM, config_file_path: str) -> str:
+@pytest.fixture
+def gpt2_config_path(tmpdir_factory: pytest.TempdirFactory, initialized_model: GPT2LLM, config_file_path: str) -> str:
+    tmp_path = tmpdir_factory.mktemp("gpt2_model")
     new_config_filename = tmp_path / "gpt2_config_test.yaml"
     model_path = tmp_path / "model.pth"
     shutil.copy(config_file_path, new_config_filename)
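
The switch from tmp_path to tmpdir_factory is not explained in the commit message; the practical difference is scope. pytest's tmp_path fixture is function-scoped, while tmpdir_factory is session-scoped and can therefore back fixtures of any scope. A minimal sketch of that distinction (names here are illustrative, not from the repo):

    import pytest

    # Hypothetical sketch: tmp_path cannot be requested by a session-scoped
    # fixture, but the session-scoped tmpdir_factory can. mktemp() creates a
    # fresh, uniquely named directory under pytest's base temp directory.
    @pytest.fixture(scope="session")
    def shared_dir(tmpdir_factory: pytest.TempdirFactory):
        return tmpdir_factory.mktemp("shared")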

tests/conversion/gpt2/test_conversion_model.py

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,12 @@
 from tests.conversion.gpt2.helper import check_same_weight_base_modules, check_same_weight_model
 
 
+def test_convert_model_can_generate(gpt2_config_path: str):
+    modalities_config = load_app_config_dict(gpt2_config_path)
+    hf_model, _ = convert_model_checkpoint(modalities_config)
+    assert hf_model.can_generate()
+
+
 def test_convert_model_checkpoint_does_not_change_weights(gpt2_config_path: str):
     modalities_config = load_app_config_dict(gpt2_config_path)
     hf_model, modalities_model = convert_model_checkpoint(modalities_config)
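
The new test relies on Hugging Face's PreTrainedModel.can_generate(), which reports whether a model class supports the .generate() API. A minimal standalone sketch of the same check (the "gpt2" checkpoint is illustrative, not part of this commit):

    from transformers import AutoModelForCausalLM

    # can_generate() is True for causal-LM classes that implement generate().
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    assert model.can_generate()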
Lines changed: 36 additions & 16 deletions
@@ -1,33 +1,53 @@
 from pathlib import Path
 
+import pytest
 import torch
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, PreTrainedModel
 
 from modalities.config.config import load_app_config_dict
 from modalities.conversion.gpt2.conversion_model import check_converted_model
 from modalities.conversion.gpt2.convert_gpt2 import convert_gpt2
+from modalities.models.gpt2.gpt2_model import GPT2LLM
 from modalities.models.utils import ModelTypeEnum, get_model_from_config
 from tests.conversion.gpt2.helper import check_same_weight_model
 
 
-def test_converting_gpt2_does_not_change_weights(tmp_path: Path, gpt2_config_path: str):
-    output_dir = tmp_path / "output"
-    convert_gpt2(gpt2_config_path, output_dir)
-    modalities_config = load_app_config_dict(gpt2_config_path)
-    original_model = get_model_from_config(modalities_config, model_type=ModelTypeEnum.CHECKPOINTED_MODEL)
-    converted_model = AutoModelForCausalLM.from_pretrained(output_dir, local_files_only=True, trust_remote_code=True)
+def test_converting_gpt2_does_not_change_weights(converted_model: PreTrainedModel, original_model: GPT2LLM):
     check_same_weight_model(converted_model, original_model)
 
 
-def test_converting_gpt2_does_not_change_outputs(tmp_path: Path, gpt2_config_path: str):
-    output_dir = tmp_path / "output"
-    convert_gpt2(gpt2_config_path, output_dir)
-    modalities_config = load_app_config_dict(gpt2_config_path)
-    original_model = get_model_from_config(modalities_config, model_type=ModelTypeEnum.CHECKPOINTED_MODEL)
-    converted_model = AutoModelForCausalLM.from_pretrained(
-        output_dir, local_files_only=True, trust_remote_code=True
-    ).to(dtype=torch.bfloat16)
-    vocab_size = modalities_config["model_raw" if "model_raw" in modalities_config else "model"]["config"]["vocab_size"]
+def test_converting_gpt2_does_not_change_outputs(
+    converted_model: PreTrainedModel, original_model: GPT2LLM, vocab_size: int
+):
     check_converted_model(
         hf_model=converted_model, modalities_model=original_model, num_testruns=1, vocab_size=vocab_size
     )
+
+
+@pytest.fixture
+def converted_model(run_convert_gpt2: None, output_dir: Path) -> PreTrainedModel:
+    return AutoModelForCausalLM.from_pretrained(output_dir, local_files_only=True, trust_remote_code=True).to(
+        dtype=torch.bfloat16
+    )
+
+
+@pytest.fixture
+def run_convert_gpt2(gpt2_config_path: str, output_dir: Path):
+    convert_gpt2(gpt2_config_path, output_dir)
+
+
+@pytest.fixture
+def original_model(gpt2_config_path: str) -> GPT2LLM:
+    modalities_config = load_app_config_dict(gpt2_config_path)
+    return get_model_from_config(modalities_config, model_type=ModelTypeEnum.CHECKPOINTED_MODEL)
+
+
+@pytest.fixture
+def vocab_size(gpt2_config_path: str) -> int:
+    modalities_config = load_app_config_dict(gpt2_config_path)
+    return modalities_config["model_raw" if "model_raw" in modalities_config else "model"]["config"]["vocab_size"]
+
+
+@pytest.fixture
+def output_dir(tmp_path: Path) -> Path:
+    return tmp_path / "output"
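
The refactoring above replaces duplicated setup in the two tests with chained fixtures: run_convert_gpt2 performs the conversion, converted_model and original_model load the two models, and each test shrinks to a single assertion over the fixtures it declares. A generic sketch of this fixture-chaining pattern (all names hypothetical, not from the repo):

    import pytest
    from pathlib import Path

    # Downstream fixtures request upstream ones by parameter name;
    # pytest resolves the whole chain for each test.
    @pytest.fixture
    def build_dir(tmp_path: Path) -> Path:
        return tmp_path / "build"

    @pytest.fixture
    def built_file(build_dir: Path) -> Path:
        build_dir.mkdir()
        out = build_dir / "result.txt"
        out.write_text("ok")
        return out

    def test_build_output(built_file: Path):
        # The test only declares what it consumes.
        assert built_file.read_text() == "ok"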
