Skip to content

Commit 6332939

Browse files
authored
fix: Ensure uint64 fields are handled correctly in _create_dataset (#791)
1 parent c8a94a4 commit 6332939

7 files changed

Lines changed: 789 additions & 32 deletions

File tree

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
torch
22
torchvision
33
lightning-utilities
4-
filelock
4+
filelock <3.24 # v3.24.0 removed lock file auto-delete on Windows, breaking our cleanup logic
55
numpy
66
boto3
77
requests

setup.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from importlib.util import module_from_spec, spec_from_file_location
55
from pathlib import Path
66

7-
from pkg_resources import parse_requirements
87
from setuptools import find_packages, setup
98

109
_PATH_ROOT = os.path.dirname(__file__)
@@ -19,8 +18,12 @@ def _load_py_module(fname, pkg="litdata"):
1918
return py
2019

2120

21+
about = _load_py_module("__about__.py")
22+
requirements_module = _load_py_module("requirements.py")
23+
24+
2225
def _load_requirements(path_dir: str = _PATH_ROOT, file_name: str = "requirements.txt") -> list:
23-
reqs = parse_requirements(open(os.path.join(path_dir, file_name)).readlines())
26+
reqs = requirements_module._parse_requirements(open(os.path.join(path_dir, file_name)).readlines())
2427
return list(map(str, reqs))
2528

2629

src/litdata/imports.py

Lines changed: 267 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,32 @@
11
# Copyright The Lightning AI team.
22
# Licensed under the Apache License, Version 2.0 (the "License");
3-
# you may not use this file except in compliance with the License.
4-
# You may obtain a copy of the License at
5-
#
63
# http://www.apache.org/licenses/LICENSE-2.0
7-
#
8-
# Unless required by applicable law or agreed to in writing, software
9-
# distributed under the License is distributed on an "AS IS" BASIS,
10-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11-
# See the License for the specific language governing permissions and
12-
# limitations under the License.
134

5+
import functools
146
import importlib
7+
import os
8+
import warnings
9+
from collections.abc import Callable
1510
from functools import lru_cache
11+
from importlib.metadata import PackageNotFoundError, distribution
12+
from importlib.metadata import version as _version
1613
from importlib.util import find_spec
17-
from typing import TypeVar
14+
from types import ModuleType
15+
from typing import Any, TypeVar
1816

19-
import pkg_resources
17+
from packaging.requirements import Requirement
18+
from packaging.version import InvalidVersion, Version
2019
from typing_extensions import ParamSpec
2120

2221
T = TypeVar("T")
2322
P = ParamSpec("P")
2423

24+
try:
25+
from importlib import metadata
26+
except ImportError:
27+
# Python < 3.8
28+
import importlib_metadata as metadata # type: ignore
29+
2530

2631
@lru_cache
2732
def package_available(package_name: str) -> bool:
@@ -61,6 +66,30 @@ def module_available(module_path: str) -> bool:
6166
return True
6267

6368

69+
def compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
70+
"""Compare package version with some requirements.
71+
72+
>>> compare_version("torch", operator.ge, "0.1")
73+
True
74+
>>> compare_version("does_not_exist", operator.ge, "0.0")
75+
False
76+
77+
"""
78+
try:
79+
pkg = importlib.import_module(package)
80+
except (ImportError, RuntimeError):
81+
return False
82+
try:
83+
# Use importlib.metadata to infer version
84+
pkg_version = Version(pkg.__version__) if hasattr(pkg, "__version__") else Version(_version(package))
85+
except (TypeError, PackageNotFoundError):
86+
# this is mocked by Sphinx, so it should return True to generate all summaries
87+
return True
88+
if use_base_version:
89+
pkg_version = Version(pkg_version.base_version)
90+
return op(pkg_version, Version(version))
91+
92+
6493
class RequirementCache:
6594
"""Boolean-like class to check for requirement and module availability.
6695
@@ -80,42 +109,255 @@ class RequirementCache:
80109
True
81110
>>> bool(RequirementCache("unknown_package"))
82111
False
112+
>>> bool(RequirementCache(module="torch.utils"))
113+
True
114+
>>> bool(RequirementCache(module="unknown_package"))
115+
False
116+
>>> bool(RequirementCache(module="unknown.module.path"))
117+
False
83118
84119
"""
85120

86-
def __init__(self, requirement: str, module: str | None = None) -> None:
121+
def __init__(self, requirement: str | None = None, module: str | None = None) -> None:
122+
if not (requirement or module):
123+
raise ValueError("At least one arguments need to be set.")
87124
self.requirement = requirement
88125
self.module = module
89126

90127
def _check_requirement(self) -> None:
91-
if hasattr(self, "available"):
92-
return
128+
if not self.requirement:
129+
raise ValueError("Requirement name is required.")
93130
try:
94-
# first try the pkg_resources requirement
95-
pkg_resources.require(self.requirement)
96-
self.available = True
97-
self.message = f"Requirement {self.requirement!r} met"
98-
except Exception as ex:
131+
req = Requirement(self.requirement)
132+
pkg_version = Version(_version(req.name))
133+
self.available = req.specifier.contains(pkg_version, prereleases=True) and (
134+
not req.extras or self._check_extras_available(req)
135+
)
136+
except (PackageNotFoundError, InvalidVersion) as ex:
99137
self.available = False
100-
self.message = f"{ex.__class__.__name__}: {ex}.\n HINT: Try running `pip install -U {self.requirement!r}`"
101-
requirement_contains_version_specifier = any(c in self.requirement for c in "=<>")
102-
if not requirement_contains_version_specifier or self.module is not None:
138+
self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"
139+
140+
if self.available:
141+
self.message = f"Requirement {self.requirement!r} met"
142+
else:
143+
req_include_version = any(c in self.requirement for c in "=<>")
144+
if not req_include_version or self.module is not None:
103145
module = self.requirement if self.module is None else self.module
104-
# sometimes `pkg_resources.require()` fails but the module is importable
146+
# Sometimes `importlib.metadata.version` fails but the module is importable
105147
self.available = module_available(module)
106148
if self.available:
107149
self.message = f"Module {module!r} available"
150+
self.message = (
151+
f"Requirement {self.requirement!r} not met. HINT: Try running `pip install -U {self.requirement!r}`"
152+
)
153+
154+
def _check_module(self) -> None:
155+
if not self.module:
156+
raise ValueError("Module name is required.")
157+
self.available = module_available(self.module)
158+
if self.available:
159+
self.message = f"Module {self.module!r} available"
160+
else:
161+
self.message = f"Module not found: {self.module!r}. HINT: Try running `pip install -U {self.module}`"
162+
163+
def _check_available(self) -> None:
164+
if hasattr(self, "available"):
165+
return
166+
if self.requirement:
167+
self._check_requirement()
168+
if getattr(self, "available", True) and self.module:
169+
self._check_module()
170+
171+
def _check_extras_available(self, requirement: Requirement) -> bool:
172+
if not requirement.extras:
173+
return True
174+
175+
extra_requirements = self._get_extra_requirements(requirement)
176+
177+
if not extra_requirements:
178+
# The specified extra is not found in the package metadata
179+
return False
180+
181+
# Verify each extra requirement is installed
182+
for extra_req in extra_requirements:
183+
try:
184+
extra_dist = distribution(extra_req.name)
185+
extra_installed_version = Version(extra_dist.version)
186+
if extra_req.specifier and not extra_req.specifier.contains(extra_installed_version, prereleases=True):
187+
return False
188+
except importlib.metadata.PackageNotFoundError:
189+
return False
190+
191+
return True
192+
193+
def _get_extra_requirements(self, requirement: Requirement) -> list[Requirement]:
194+
dist = distribution(requirement.name)
195+
# Get the required dependencies for the specified extras
196+
extra_requirements = dist.metadata.get_all("Requires-Dist") or []
197+
return [Requirement(r) for r in extra_requirements if any(extra in r for extra in requirement.extras)]
108198

109199
def __bool__(self) -> bool:
110200
"""Format as bool."""
111-
self._check_requirement()
201+
self._check_available()
112202
return self.available
113203

114204
def __str__(self) -> str:
115205
"""Format as string."""
116-
self._check_requirement()
206+
self._check_available()
117207
return self.message
118208

119209
def __repr__(self) -> str:
120210
"""Format as string."""
121211
return self.__str__()
212+
213+
214+
class ModuleAvailableCache(RequirementCache):
215+
"""Boolean-like class for check of module availability.
216+
217+
>>> ModuleAvailableCache("torch")
218+
Module 'torch' available
219+
>>> bool(ModuleAvailableCache("torch.utils"))
220+
True
221+
>>> bool(ModuleAvailableCache("unknown_package"))
222+
False
223+
>>> bool(ModuleAvailableCache("unknown.module.path"))
224+
False
225+
226+
"""
227+
228+
def __init__(self, module: str) -> None:
229+
warnings.warn(
230+
"`ModuleAvailableCache` is a special case of `RequirementCache`."
231+
" Please use `RequirementCache(module=...)` instead.",
232+
DeprecationWarning,
233+
stacklevel=4,
234+
)
235+
super().__init__(module=module)
236+
237+
238+
def get_dependency_min_version_spec(package_name: str, dependency_name: str) -> str:
239+
"""Return the minimum version specifier of a dependency of a package.
240+
241+
>>> get_dependency_min_version_spec("pytorch-lightning==1.8.0", "jsonargparse")
242+
'>=4.12.0'
243+
244+
"""
245+
dependencies = metadata.requires(package_name) or []
246+
for dep in dependencies:
247+
dependency = Requirement(dep)
248+
if dependency.name == dependency_name:
249+
spec = [str(s) for s in dependency.specifier if str(s)[0] == ">"]
250+
return spec[0] if spec else ""
251+
raise ValueError(
252+
"This is an internal error. Please file a GitHub issue with the error message. Dependency "
253+
f"{dependency_name!r} not found in package {package_name!r}."
254+
)
255+
256+
257+
class LazyModule(ModuleType):
258+
"""Proxy module that lazily imports the underlying module the first time it is actually used.
259+
260+
Args:
261+
module_name: the fully-qualified module name to import
262+
callback: a callback function to call before importing the module
263+
264+
"""
265+
266+
def __init__(self, module_name: str, callback: Callable | None = None) -> None:
267+
super().__init__(module_name)
268+
self._module: Any = None
269+
self._callback = callback
270+
271+
def __getattr__(self, item: str) -> Any:
272+
"""Lazily import the underlying module and delegate attribute access to it."""
273+
if self._module is None:
274+
self._import_module()
275+
276+
return getattr(self._module, item)
277+
278+
def __dir__(self) -> list[str]:
279+
"""Lazily import the underlying module and return its attributes for introspection (dir())."""
280+
if self._module is None:
281+
self._import_module()
282+
283+
return dir(self._module)
284+
285+
def _import_module(self) -> None:
286+
# Execute callback, if any
287+
if self._callback is not None:
288+
self._callback()
289+
290+
# Actually import the module
291+
self._module = importlib.import_module(self.__name__)
292+
293+
# Update this object's dict so that attribute references are efficient
294+
# (__getattr__ is only called on lookups that fail)
295+
self.__dict__.update(self._module.__dict__)
296+
297+
298+
def lazy_import(module_name: str, callback: Callable | None = None) -> LazyModule:
299+
"""Return a proxy module object that will lazily import the given module the first time it is used.
300+
301+
Example usage:
302+
303+
# Lazy version of `import tensorflow as tf`
304+
tf = lazy_import("tensorflow")
305+
# Other commands
306+
# Now the module is loaded
307+
tf.__version__
308+
309+
Args:
310+
module_name: the fully-qualified module name to import
311+
callback: a callback function to call before importing the module
312+
313+
Returns:
314+
a proxy module object that will be lazily imported when first used
315+
316+
"""
317+
return LazyModule(module_name, callback=callback)
318+
319+
320+
def requires(*module_path_version: str, raise_exception: bool = True) -> Callable[[Callable[P, T]], Callable[P, T]]:
321+
"""Decorator to check optional dependencies at call time with a clear error/warning message.
322+
323+
Args:
324+
module_path_version: Python module paths (e.g., ``"torch.cuda"``) and/or pip-style requirements
325+
(e.g., ``"torch>=2.0.0"``) to verify.
326+
raise_exception: If ``True``, raise ``ModuleNotFoundError`` when requirements are not satisfied;
327+
otherwise emit a warning and proceed to call the function.
328+
329+
Example:
330+
>>> @requires("libpath", raise_exception=bool(int(os.getenv("LIGHTING_TESTING", "0"))))
331+
... def my_cwd():
332+
... from pathlib import Path
333+
... return Path(__file__).parent
334+
335+
>>> class MyRndPower:
336+
... @requires("math", "random")
337+
... def __init__(self):
338+
... from math import pow
339+
... from random import randint
340+
... self._rnd = pow(randint(1, 9), 2)
341+
342+
"""
343+
344+
def decorator(func: Callable[P, T]) -> Callable[P, T]:
345+
reqs = [
346+
ModuleAvailableCache(mod_ver) if "." in mod_ver else RequirementCache(mod_ver)
347+
for mod_ver in module_path_version
348+
]
349+
available = all(map(bool, reqs))
350+
351+
@functools.wraps(func)
352+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
353+
if not available:
354+
missing = os.linesep.join([repr(r) for r in reqs if not bool(r)])
355+
msg = f"Required dependencies not available: \n{missing}"
356+
if raise_exception:
357+
raise ModuleNotFoundError(msg)
358+
warnings.warn(msg, stacklevel=2)
359+
return func(*args, **kwargs)
360+
361+
return wrapper
362+
363+
return decorator

0 commit comments

Comments
 (0)