Skip to content

Commit d391fce

Browse files
committed
fix version check
1 parent 7029d07 commit d391fce

4 files changed

Lines changed: 37 additions & 11 deletions

File tree

defuser/defuser.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
from defuser.modeling.update_module import update_module
1313
from defuser.utils.common import (
1414
MIN_SUPPORTED_TRANSFORMERS_VERSION,
15+
is_version_at_least,
1516
is_supported_transformers_version,
1617
warn_if_public_api_transformers_unsupported,
1718
)
18-
from packaging import version
1919
import transformers
2020
from logbar import LogBar
2121

@@ -69,7 +69,7 @@ def replace_fused_blocks(model_type: str) -> bool:
6969
custom_class = getattr(custom_module, custom_class_name)
7070
setattr(orig_module, orig_class_name, custom_class)
7171

72-
if version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION):
72+
if is_version_at_least(transformers.__version__, MIN_SUPPORTED_TRANSFORMERS_VERSION):
7373
from transformers import conversion_mapping
7474

7575
if not hasattr(conversion_mapping, "orig_get_checkpoint_conversion_mapping"):
@@ -102,8 +102,7 @@ def check_model_compatibility(model: nn.Module) -> bool:
102102
return False
103103

104104
min_ver = MODEL_CONFIG[model_type].get("min_transformers_version")
105-
current_ver = version.parse(transformers.__version__)
106-
if min_ver and current_ver < version.parse(min_ver):
105+
if min_ver and not is_version_at_least(transformers.__version__, min_ver):
107106
logger.warn(
108107
f"Skip conversion for model_type={model_type}: "
109108
f"requires transformers>={min_ver}, current version is {transformers.__version__}."

defuser/utils/common.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,35 @@ class ModuleNameFilter:
2828
negative: tuple[pcre.Pattern, ...]
2929

3030

31+
def _parse_version(value: str | version.Version) -> version.Version:
32+
"""Return a normalized packaging version object."""
33+
if isinstance(value, version.Version):
34+
return value
35+
return version.parse(value)
36+
37+
38+
def is_version_at_least(
39+
installed_version: str | version.Version,
40+
minimum_version: str | version.Version,
41+
) -> bool:
42+
"""Return whether a version meets a minimum, allowing same-release dev snapshots.
43+
44+
Hugging Face main-branch builds report versions like ``5.3.0-dev`` which
45+
packaging normalizes to ``5.3.0.dev0`` and orders before ``5.3.0``. Defuser
46+
treats those dev snapshots as satisfying the corresponding stable floor.
47+
"""
48+
installed = _parse_version(installed_version)
49+
minimum = _parse_version(minimum_version)
50+
51+
if installed >= minimum:
52+
return True
53+
54+
if installed.is_devrelease:
55+
return version.parse(installed.base_version) >= minimum
56+
57+
return False
58+
59+
3160
def env_flag(name: str, default: str | bool | None = "0") -> bool:
3261
"""Return ``True`` when an env var is set to a truthy value."""
3362

@@ -46,14 +75,14 @@ def is_transformers_version_greater_or_equal_5():
4675
"""Cache the coarse ``transformers>=5`` capability check used by fast paths."""
4776
import transformers
4877

49-
return version.parse(transformers.__version__) >= version.parse("5.0.0")
78+
return is_version_at_least(transformers.__version__, "5.0.0")
5079

5180

5281
def is_supported_transformers_version() -> bool:
5382
"""Return whether the installed transformers version is supported by Defuser's public API."""
5483
import transformers
5584

56-
return version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION)
85+
return is_version_at_least(transformers.__version__, MIN_SUPPORTED_TRANSFORMERS_VERSION)
5786

5887

5988
def warn_if_public_api_transformers_unsupported(api_name: str, logger) -> bool:

defuser/utils/hf.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
import torch
1313
import transformers
1414
from logbar import LogBar
15-
from packaging import version
1615
from transformers import AutoConfig
1716

1817
from defuser.model_registry import MODEL_CONFIG
19-
from defuser.utils.common import env_flag, warn_if_public_api_transformers_unsupported
18+
from defuser.utils.common import env_flag, is_version_at_least, warn_if_public_api_transformers_unsupported
2019

2120
logger = LogBar(__name__)
2221

@@ -77,8 +76,7 @@ def pre_check_config(model_name: str | torch.nn.Module):
7776
cfg = MODEL_CONFIG[model_type]
7877

7978
min_ver = cfg.get("min_transformers_version")
80-
tf_ver = version.parse(transformers.__version__)
81-
if min_ver and tf_ver < version.parse(min_ver):
79+
if min_ver and not is_version_at_least(transformers.__version__, min_ver):
8280
return False
8381
try:
8482
file_path = get_file_path_via_model_name(model_name, "model.safetensors.index.json")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
99

1010
[project]
1111
name = "Defuser"
12-
version = "0.0.16"
12+
version = "0.0.17"
1313
description = "Model defuser helper for HF Transformers."
1414
readme = "README.md"
1515
requires-python = ">=3.9"

0 commit comments

Comments
 (0)