-
Notifications
You must be signed in to change notification settings - Fork 120
Expand file tree
/
Copy pathtest_utils.py
More file actions
82 lines (63 loc) · 2.88 KB
/
test_utils.py
File metadata and controls
82 lines (63 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
import pytest
from model2vec.distill.utils import select_optimal_device
from model2vec.modelcards import get_metadata_from_readme
from model2vec.utils import get_package_extras, importable
def test__get_metadata_from_readme_not_exists() -> None:
    """Reading metadata from a README path that does not exist yields an empty dict."""
    missing_readme = Path("zzz")
    metadata = get_metadata_from_readme(missing_readme)
    assert metadata == {}
def test__get_metadata_from_readme_mocked_file(tmp_path: Path) -> None:
    """A README whose front matter holds ``key: value`` exposes that pair as metadata."""
    readme = tmp_path.joinpath("README.md")
    readme.write_text("---\nkey: value\n---\n", encoding="utf-8")
    metadata = get_metadata_from_readme(readme)
    assert metadata["key"] == "value"
def test__get_metadata_from_readme_mocked_file_keys(tmp_path: Path) -> None:
    """A README whose content has no ``---`` front-matter block yields no metadata keys."""
    readme = tmp_path.joinpath("README.md")
    readme.write_text("b", encoding="utf-8")
    keys = set(get_metadata_from_readme(readme))
    assert not keys
@pytest.mark.parametrize(
    "torch_version, device, expected, cuda, mps, should_raise",
    [
        ("2.7.0", "cpu", "cpu", True, True, False),
        ("2.8.0", "cpu", "cpu", True, True, False),
        ("2.7.0", "clown", "clown", False, False, False),
        ("2.8.0", "clown", "clown", False, False, False),
        ("2.7.0", "mps", "mps", False, True, False),
        ("2.8.0", "mps", None, False, True, True),
        ("2.7.0", None, "cuda", True, True, False),
        ("2.7.0", None, "mps", False, True, False),
        ("2.7.0", None, "cpu", False, False, False),
        ("2.8.0", None, "cuda", True, True, False),
        ("2.8.0", None, "cpu", False, True, False),
        ("2.8.0", None, "cpu", False, False, False),
        ("2.9.0", None, "cpu", False, True, False),
        ("3.0.0", None, "cpu", False, True, False),
    ],
)
def test_select_optimal_device(torch_version, device, expected, cuda, mps, should_raise) -> None:
    """Check device selection across patched torch versions and backend availability.

    Each case patches ``torch.cuda.is_available``, ``torch.backends.mps.is_available``
    and ``torch.__version__``. It then expects ``select_optimal_device(device)`` either
    to return ``expected`` or, when ``should_raise`` is set, to raise ``RuntimeError``.
    """
    with patch("torch.cuda.is_available", return_value=cuda), patch(
        "torch.backends.mps.is_available", return_value=mps
    ), patch("torch.__version__", torch_version):
        if not should_raise:
            assert select_optimal_device(device) == expected
            return
        # The only case flagged should_raise above requests "mps" explicitly on torch 2.8.0.
        with pytest.raises(RuntimeError):
            select_optimal_device(device)
def test_importable() -> None:
    """Test the importable function.

    A module that cannot be imported must raise ImportError. An importable
    module ("os") must not raise.
    """
    # Keep only the failing call inside pytest.raises. Any statement placed after
    # the raising call in the same block would silently never run.
    with pytest.raises(ImportError):
        importable("clown", "clown")
    # "os" is a stdlib module, so this call is expected to complete without raising.
    importable("os", "clown")
def test_get_package_extras() -> None:
    """Check that model2vec's "distill" extra resolves to exactly the expected packages."""
    expected_extras = {"skeletoken", "torch", "transformers", "scikit-learn"}
    assert set(get_package_extras("model2vec", "distill")) == expected_extras
def test_get_package_extras_empty() -> None:
    """Check that asking tqdm for the empty extra name yields no packages."""
    extras = list(get_package_extras("tqdm", ""))
    assert extras == []