11# Copyright The Lightning AI team.
22# Licensed under the Apache License, Version 2.0 (the "License");
3- # you may not use this file except in compliance with the License.
4- # You may obtain a copy of the License at
5- #
63# http://www.apache.org/licenses/LICENSE-2.0
7- #
8- # Unless required by applicable law or agreed to in writing, software
9- # distributed under the License is distributed on an "AS IS" BASIS,
10- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11- # See the License for the specific language governing permissions and
12- # limitations under the License.
134
5+ import functools
146import importlib
7+ import os
8+ import warnings
9+ from collections .abc import Callable
1510from functools import lru_cache
11+ from importlib .metadata import PackageNotFoundError , distribution
12+ from importlib .metadata import version as _version
1613from importlib .util import find_spec
17- from typing import TypeVar
14+ from types import ModuleType
15+ from typing import Any , TypeVar
1816
19- import pkg_resources
17+ from packaging .requirements import Requirement
18+ from packaging .version import InvalidVersion , Version
2019from typing_extensions import ParamSpec
2120
2221T = TypeVar ("T" )
2322P = ParamSpec ("P" )
2423
24+ try :
25+ from importlib import metadata
26+ except ImportError :
27+ # Python < 3.8
28+ import importlib_metadata as metadata # type: ignore
29+
2530
2631@lru_cache
2732def package_available (package_name : str ) -> bool :
@@ -61,6 +66,30 @@ def module_available(module_path: str) -> bool:
6166 return True
6267
6368
69+ def compare_version (package : str , op : Callable , version : str , use_base_version : bool = False ) -> bool :
70+ """Compare package version with some requirements.
71+
72+ >>> compare_version("torch", operator.ge, "0.1")
73+ True
74+ >>> compare_version("does_not_exist", operator.ge, "0.0")
75+ False
76+
77+ """
78+ try :
79+ pkg = importlib .import_module (package )
80+ except (ImportError , RuntimeError ):
81+ return False
82+ try :
83+ # Use importlib.metadata to infer version
84+ pkg_version = Version (pkg .__version__ ) if hasattr (pkg , "__version__" ) else Version (_version (package ))
85+ except (TypeError , PackageNotFoundError ):
86+ # this is mocked by Sphinx, so it should return True to generate all summaries
87+ return True
88+ if use_base_version :
89+ pkg_version = Version (pkg_version .base_version )
90+ return op (pkg_version , Version (version ))
91+
92+
6493class RequirementCache :
6594 """Boolean-like class to check for requirement and module availability.
6695
@@ -80,42 +109,255 @@ class RequirementCache:
80109 True
81110 >>> bool(RequirementCache("unknown_package"))
82111 False
112+ >>> bool(RequirementCache(module="torch.utils"))
113+ True
114+ >>> bool(RequirementCache(module="unknown_package"))
115+ False
116+ >>> bool(RequirementCache(module="unknown.module.path"))
117+ False
83118
84119 """
85120
86- def __init__ (self , requirement : str , module : str | None = None ) -> None :
121+ def __init__ (self , requirement : str | None = None , module : str | None = None ) -> None :
122+ if not (requirement or module ):
123+ raise ValueError ("At least one arguments need to be set." )
87124 self .requirement = requirement
88125 self .module = module
89126
90127 def _check_requirement (self ) -> None :
91- if hasattr ( self , "available" ) :
92- return
128+ if not self . requirement :
129+ raise ValueError ( "Requirement name is required." )
93130 try :
94- # first try the pkg_resources requirement
95- pkg_resources .require (self .requirement )
96- self .available = True
97- self .message = f"Requirement { self .requirement !r} met"
98- except Exception as ex :
131+ req = Requirement (self .requirement )
132+ pkg_version = Version (_version (req .name ))
133+ self .available = req .specifier .contains (pkg_version , prereleases = True ) and (
134+ not req .extras or self ._check_extras_available (req )
135+ )
136+ except (PackageNotFoundError , InvalidVersion ) as ex :
99137 self .available = False
100- self .message = f"{ ex .__class__ .__name__ } : { ex } .\n HINT: Try running `pip install -U { self .requirement !r} `"
101- requirement_contains_version_specifier = any (c in self .requirement for c in "=<>" )
102- if not requirement_contains_version_specifier or self .module is not None :
138+ self .message = f"{ ex .__class__ .__name__ } : { ex } . HINT: Try running `pip install -U { self .requirement !r} `"
139+
140+ if self .available :
141+ self .message = f"Requirement { self .requirement !r} met"
142+ else :
143+ req_include_version = any (c in self .requirement for c in "=<>" )
144+ if not req_include_version or self .module is not None :
103145 module = self .requirement if self .module is None else self .module
104- # sometimes `pkg_resources.require() ` fails but the module is importable
146+ # Sometimes `importlib.metadata.version ` fails but the module is importable
105147 self .available = module_available (module )
106148 if self .available :
107149 self .message = f"Module { module !r} available"
150+ self .message = (
151+ f"Requirement { self .requirement !r} not met. HINT: Try running `pip install -U { self .requirement !r} `"
152+ )
153+
154+ def _check_module (self ) -> None :
155+ if not self .module :
156+ raise ValueError ("Module name is required." )
157+ self .available = module_available (self .module )
158+ if self .available :
159+ self .message = f"Module { self .module !r} available"
160+ else :
161+ self .message = f"Module not found: { self .module !r} . HINT: Try running `pip install -U { self .module } `"
162+
163+ def _check_available (self ) -> None :
164+ if hasattr (self , "available" ):
165+ return
166+ if self .requirement :
167+ self ._check_requirement ()
168+ if getattr (self , "available" , True ) and self .module :
169+ self ._check_module ()
170+
171+ def _check_extras_available (self , requirement : Requirement ) -> bool :
172+ if not requirement .extras :
173+ return True
174+
175+ extra_requirements = self ._get_extra_requirements (requirement )
176+
177+ if not extra_requirements :
178+ # The specified extra is not found in the package metadata
179+ return False
180+
181+ # Verify each extra requirement is installed
182+ for extra_req in extra_requirements :
183+ try :
184+ extra_dist = distribution (extra_req .name )
185+ extra_installed_version = Version (extra_dist .version )
186+ if extra_req .specifier and not extra_req .specifier .contains (extra_installed_version , prereleases = True ):
187+ return False
188+ except importlib .metadata .PackageNotFoundError :
189+ return False
190+
191+ return True
192+
193+ def _get_extra_requirements (self , requirement : Requirement ) -> list [Requirement ]:
194+ dist = distribution (requirement .name )
195+ # Get the required dependencies for the specified extras
196+ extra_requirements = dist .metadata .get_all ("Requires-Dist" ) or []
197+ return [Requirement (r ) for r in extra_requirements if any (extra in r for extra in requirement .extras )]
108198
109199 def __bool__ (self ) -> bool :
110200 """Format as bool."""
111- self ._check_requirement ()
201+ self ._check_available ()
112202 return self .available
113203
114204 def __str__ (self ) -> str :
115205 """Format as string."""
116- self ._check_requirement ()
206+ self ._check_available ()
117207 return self .message
118208
119209 def __repr__ (self ) -> str :
120210 """Format as string."""
121211 return self .__str__ ()
212+
213+
214+ class ModuleAvailableCache (RequirementCache ):
215+ """Boolean-like class for check of module availability.
216+
217+ >>> ModuleAvailableCache("torch")
218+ Module 'torch' available
219+ >>> bool(ModuleAvailableCache("torch.utils"))
220+ True
221+ >>> bool(ModuleAvailableCache("unknown_package"))
222+ False
223+ >>> bool(ModuleAvailableCache("unknown.module.path"))
224+ False
225+
226+ """
227+
228+ def __init__ (self , module : str ) -> None :
229+ warnings .warn (
230+ "`ModuleAvailableCache` is a special case of `RequirementCache`."
231+ " Please use `RequirementCache(module=...)` instead." ,
232+ DeprecationWarning ,
233+ stacklevel = 4 ,
234+ )
235+ super ().__init__ (module = module )
236+
237+
238+ def get_dependency_min_version_spec (package_name : str , dependency_name : str ) -> str :
239+ """Return the minimum version specifier of a dependency of a package.
240+
241+ >>> get_dependency_min_version_spec("pytorch-lightning==1.8.0", "jsonargparse")
242+ '>=4.12.0'
243+
244+ """
245+ dependencies = metadata .requires (package_name ) or []
246+ for dep in dependencies :
247+ dependency = Requirement (dep )
248+ if dependency .name == dependency_name :
249+ spec = [str (s ) for s in dependency .specifier if str (s )[0 ] == ">" ]
250+ return spec [0 ] if spec else ""
251+ raise ValueError (
252+ "This is an internal error. Please file a GitHub issue with the error message. Dependency "
253+ f"{ dependency_name !r} not found in package { package_name !r} ."
254+ )
255+
256+
257+ class LazyModule (ModuleType ):
258+ """Proxy module that lazily imports the underlying module the first time it is actually used.
259+
260+ Args:
261+ module_name: the fully-qualified module name to import
262+ callback: a callback function to call before importing the module
263+
264+ """
265+
266+ def __init__ (self , module_name : str , callback : Callable | None = None ) -> None :
267+ super ().__init__ (module_name )
268+ self ._module : Any = None
269+ self ._callback = callback
270+
271+ def __getattr__ (self , item : str ) -> Any :
272+ """Lazily import the underlying module and delegate attribute access to it."""
273+ if self ._module is None :
274+ self ._import_module ()
275+
276+ return getattr (self ._module , item )
277+
278+ def __dir__ (self ) -> list [str ]:
279+ """Lazily import the underlying module and return its attributes for introspection (dir())."""
280+ if self ._module is None :
281+ self ._import_module ()
282+
283+ return dir (self ._module )
284+
285+ def _import_module (self ) -> None :
286+ # Execute callback, if any
287+ if self ._callback is not None :
288+ self ._callback ()
289+
290+ # Actually import the module
291+ self ._module = importlib .import_module (self .__name__ )
292+
293+ # Update this object's dict so that attribute references are efficient
294+ # (__getattr__ is only called on lookups that fail)
295+ self .__dict__ .update (self ._module .__dict__ )
296+
297+
298+ def lazy_import (module_name : str , callback : Callable | None = None ) -> LazyModule :
299+ """Return a proxy module object that will lazily import the given module the first time it is used.
300+
301+ Example usage:
302+
303+ # Lazy version of `import tensorflow as tf`
304+ tf = lazy_import("tensorflow")
305+ # Other commands
306+ # Now the module is loaded
307+ tf.__version__
308+
309+ Args:
310+ module_name: the fully-qualified module name to import
311+ callback: a callback function to call before importing the module
312+
313+ Returns:
314+ a proxy module object that will be lazily imported when first used
315+
316+ """
317+ return LazyModule (module_name , callback = callback )
318+
319+
320+ def requires (* module_path_version : str , raise_exception : bool = True ) -> Callable [[Callable [P , T ]], Callable [P , T ]]:
321+ """Decorator to check optional dependencies at call time with a clear error/warning message.
322+
323+ Args:
324+ module_path_version: Python module paths (e.g., ``"torch.cuda"``) and/or pip-style requirements
325+ (e.g., ``"torch>=2.0.0"``) to verify.
326+ raise_exception: If ``True``, raise ``ModuleNotFoundError`` when requirements are not satisfied;
327+ otherwise emit a warning and proceed to call the function.
328+
329+ Example:
330+ >>> @requires("libpath", raise_exception=bool(int(os.getenv("LIGHTING_TESTING", "0"))))
331+ ... def my_cwd():
332+ ... from pathlib import Path
333+ ... return Path(__file__).parent
334+
335+ >>> class MyRndPower:
336+ ... @requires("math", "random")
337+ ... def __init__(self):
338+ ... from math import pow
339+ ... from random import randint
340+ ... self._rnd = pow(randint(1, 9), 2)
341+
342+ """
343+
344+ def decorator (func : Callable [P , T ]) -> Callable [P , T ]:
345+ reqs = [
346+ ModuleAvailableCache (mod_ver ) if "." in mod_ver else RequirementCache (mod_ver )
347+ for mod_ver in module_path_version
348+ ]
349+ available = all (map (bool , reqs ))
350+
351+ @functools .wraps (func )
352+ def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T :
353+ if not available :
354+ missing = os .linesep .join ([repr (r ) for r in reqs if not bool (r )])
355+ msg = f"Required dependencies not available: \n { missing } "
356+ if raise_exception :
357+ raise ModuleNotFoundError (msg )
358+ warnings .warn (msg , stacklevel = 2 )
359+ return func (* args , ** kwargs )
360+
361+ return wrapper
362+
363+ return decorator
0 commit comments