From 04956ad6170e5933ed5017e6be18cc73dd73ac14 Mon Sep 17 00:00:00 2001 From: Andreas Schuh Date: Sun, 25 Jan 2026 21:57:12 +0000 Subject: [PATCH] fix: Type annotation of prepare_batch() --- .flake8 | 2 +- src/deepali/data/prepare.py | 23 ++++++++++------------- src/deepali/data/sample.py | 8 ++++---- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/.flake8 b/.flake8 index 4fd8017..d0483cc 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ [flake8] max-line-length = 100 select = C,E,F,W,B,B950 -ignore = E203, E402, E501, W503 +ignore = E203, E402, E501, E704, W503, BLK100 \ No newline at end of file diff --git a/src/deepali/data/prepare.py b/src/deepali/data/prepare.py index 50aba57..439dbaa 100644 --- a/src/deepali/data/prepare.py +++ b/src/deepali/data/prepare.py @@ -2,7 +2,7 @@ from collections import abc from dataclasses import is_dataclass -from typing import Any, Mapping, NamedTuple, Optional, Sequence, Union, overload +from typing import Any, Dict, Mapping, NamedTuple, Optional, Sequence, Union, overload import torch from torch import Tensor @@ -17,36 +17,33 @@ @overload -def prepare_batch(batch: Sequence[Mapping[str, Any]]) -> Mapping[str, Any]: - ... +def prepare_batch(batch: Mapping) -> Dict[str, Any]: ... @overload -def prepare_batch(batch: Sequence[Dataclass]) -> Dataclass: - ... +def prepare_batch(batch: Dataclass) -> Dataclass: ... @overload -def prepare_batch(batch: Sequence[NamedTuple]) -> NamedTuple: - ... +def prepare_batch(batch: NamedTuple) -> NamedTuple: ... def prepare_batch( - batch: Batch, + batch: Union[Mapping, Dataclass, NamedTuple], device: Optional[Union[Device, str]] = None, non_blocking: bool = False, memory_format=torch.preserve_format, ) -> Batch: r"""Move batch data to execution device.""" - names = sample_field_names(batch) + names = sample_field_names(batch) # type: ignore[arg-type] values = [] for name in names: - value = sample_field_value(batch, name) + value = sample_field_value(batch, name) # type: ignore[arg-type] value = prepare_item( value, device=device, non_blocking=non_blocking, memory_format=memory_format ) values.append(value) - return replace_all_sample_field_values(batch, values) + return replace_all_sample_field_values(batch, values) # type: ignore[arg-type] def prepare_item( @@ -56,11 +53,11 @@ def prepare_item( memory_format=torch.preserve_format, ) -> Any: r"""Move batch item data to execution device.""" - kwargs = dict(device=device, non_blocking=non_blocking, memory_format=memory_format) + kwargs: dict = dict(device=device, non_blocking=non_blocking, memory_format=memory_format) if isinstance(value, Tensor): value = value.to(**kwargs) elif isinstance(value, abc.Mapping) or is_dataclass(value) or is_namedtuple(value): - value = prepare_batch(value, **kwargs) + value = prepare_batch(value, **kwargs) # type: ignore[arg-type] elif isinstance(value, Sequence) and not isinstance(value, str): value = [prepare_item(item, **kwargs) for item in value] return value diff --git a/src/deepali/data/sample.py b/src/deepali/data/sample.py index 63a2fcc..aff5e2f 100644 --- a/src/deepali/data/sample.py +++ b/src/deepali/data/sample.py @@ -16,12 +16,12 @@ ) -def sample_field_names(sample: Sample) -> Tuple[str]: +def sample_field_names(sample: Sample) -> Tuple[str, ...]: r"""Get names of fields in data sample.""" if is_dataclass(sample): - return tuple((field.name for field in fields(sample))) + return tuple((field.name for field in fields(sample))) # type: ignore[arg-type] if is_namedtuple(sample): - return sample._fields + return sample._fields # type: ignore[arg-type] if not isinstance(sample, Mapping): raise TypeError("Dataset 'sample' must be dataclass, Mapping, or NamedTuple") return tuple(sample.keys()) @@ -45,7 +45,7 @@ def replace_all_sample_field_values(sample: Sample, values: Sequence[Any]) -> Sa setattr(result, name, value) return result if is_namedtuple(sample): - return sample._replace(**{name: value for name, value in zip(names, values)}) + return sample._replace(**{name: value for name, value in zip(names, values)}) # type: ignore[arg-type] if isinstance(sample, OrderedDict): return OrderedDict([(name, value) for name, value in zip(names, values)]) return {name: value for name, value in zip(names, values)}