-from kedro.io.partitioned_dataset import PartitionedDataSet
+import warnings
+from copy import deepcopy
+from typing import Any
+from urllib.parse import urlparse
 
+from kedro.io.core import (
+    VERSION_KEY,
+    VERSIONED_FLAG_KEY,
+    AbstractDataSet,
+    DatasetError,
+    parse_dataset_definition,
+)
+from kedro.io.partitioned_dataset import S3_PROTOCOLS
 
-class Multiset(PartitionedDataSet):
-    """Modified Partitioned Dataset for pasteur."""
+
+class Multiset(AbstractDataSet):
+    # noqa: too-many-instance-attributes,protected-access
+    """Simplified version of the partitioned dataset: it is not lazy, so
+    ``load`` returns the materialised data of every partition rather than a
+    dict of load callables.
+
+    def __init__(  # noqa: too-many-arguments
+        self,
+        path: str,
+        dataset: str | type[AbstractDataSet] | dict[str, Any],
+        filepath_arg: str = "filepath",
+        filename_suffix: str = "",
+        credentials: dict[str, Any] | None = None,
+        load_args: dict[str, Any] | None = None,
+        metadata: dict[str, Any] | None = None,
+    ):
+        # noqa: import-outside-toplevel
+        from fsspec.utils import infer_storage_options  # for performance reasons
+
+        super().__init__()
+
+        self._path = path
+        self._filename_suffix = filename_suffix
+        self._protocol = infer_storage_options(self._path)["protocol"]
+        self.metadata = metadata
+
+        dataset = dataset if isinstance(dataset, dict) else {"type": dataset}
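+        # parse_dataset_definition turns a type string/class plus extra keys
+        # into the underlying dataset class and its constructor config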
+        self._dataset_type, self._dataset_config = parse_dataset_definition(dataset)
+        if VERSION_KEY in self._dataset_config:
+            raise DatasetError(
+                f"'{self.__class__.__name__}' does not support versioning of the "
+                f"underlying dataset. Please remove '{VERSIONED_FLAG_KEY}' flag from "
+                f"the dataset definition."
+            )
+
+        self._credentials = deepcopy(credentials) or {}
+        self._filepath_arg = filepath_arg
+        if self._filepath_arg in self._dataset_config:
+            warnings.warn(
+                f"'{self._filepath_arg}' key must not be specified in the dataset "
+                f"definition as it will be overwritten by partition path"
+            )
+
+        self._load_args = deepcopy(load_args) or {}
+        self._sep = self._filesystem.sep
+        # since some filesystem implementations may implement a global cache
+        self._invalidate_caches()
+
+    @property
+    def _filesystem(self):
+        # for performance reasons
+        import fsspec  # noqa: import-outside-toplevel
+
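+        # treat Spark-style S3 aliases ("s3a", "s3n") as plain "s3" for fsspec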
+        protocol = "s3" if self._protocol in S3_PROTOCOLS else self._protocol
+        return fsspec.filesystem(protocol, **self._credentials)
+
+    @property
+    def _normalized_path(self) -> str:
+        if self._protocol in S3_PROTOCOLS:
+            return urlparse(self._path)._replace(scheme="s3").geturl()
+        return self._path
+
+    def _list_partitions(self) -> list[str]:
+        return [
+            path
+            for path in self._filesystem.find(self._normalized_path, **self._load_args)
+            if path.endswith(self._filename_suffix)
+        ]
+
+    def _join_protocol(self, path: str) -> str:
+        protocol_prefix = f"{self._protocol}://"
+        if self._path.startswith(protocol_prefix) and not path.startswith(
+            protocol_prefix
+        ):
+            return f"{protocol_prefix}{path}"
+        return path
+
+    def _partition_to_path(self, path: str) -> str:
+        dir_path = self._path.rstrip(self._sep)
+        path = path.lstrip(self._sep)
+        full_path = self._sep.join([dir_path, path]) + self._filename_suffix
+        return full_path
+
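+    # e.g. (hypothetical values) "s3://bucket/base/2023/part.pq" under path
+    # "s3://bucket/base" with suffix ".pq" maps to partition id "2023/part"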
+    def _path_to_partition(self, path: str) -> str:
+        dir_path = self._filesystem._strip_protocol(self._normalized_path)
+        path = path.split(dir_path, 1).pop().lstrip(self._sep)
+        if self._filename_suffix and path.endswith(self._filename_suffix):
+            path = path[: -len(self._filename_suffix)]
+        return path
+
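+    # unlike PartitionedDataSet._load, every partition is materialised here
+    # eagerly instead of being returned as a load callable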
+    def _load(self) -> dict[str, Any]:
+        partitions = {}
+
+        for partition in self._list_partitions():
+            kwargs = deepcopy(self._dataset_config)
+            # join the protocol back since PySpark may rely on it
+            kwargs[self._filepath_arg] = self._join_protocol(partition)
+            dataset = self._dataset_type(**kwargs)  # type: ignore
+            partition_id = self._path_to_partition(partition)
+            partitions[partition_id] = dataset.load()
+
+        return partitions
+
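+    # saving still accepts lazy values: callables are invoked to produce the
+    # data right before the underlying dataset writes it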
+    def _save(self, data: dict[str, Any]) -> None:
+        for partition_id, partition_data in sorted(data.items()):
+            kwargs = deepcopy(self._dataset_config)
+            partition = self._partition_to_path(partition_id)
+            # join the protocol back since tools like PySpark may rely on it
+            kwargs[self._filepath_arg] = self._join_protocol(partition)
+            dataset = self._dataset_type(**kwargs)  # type: ignore
+            if callable(partition_data):
+                partition_data = partition_data()  # noqa: redefined-loop-name
+            dataset.save(partition_data)
+
+        self._invalidate_caches()
+
+    def _describe(self) -> dict[str, Any]:
+        clean_dataset_config = (
+            dict(self._dataset_config)
+            if isinstance(self._dataset_config, dict)
+            else self._dataset_config
+        )
+        return {
+            "path": self._path,
+            "dataset_type": self._dataset_type.__name__,
+            "dataset_config": clean_dataset_config,
+        }
+
+    def _invalidate_caches(self):
+        self._filesystem.invalidate_cache(self._normalized_path)
 
     def reset(self):
         """Removes the dataset from disk so that there are no stray partitions in subsequent runs."""
         if self._filesystem.exists(self._normalized_path):
             self._filesystem.rm(self._normalized_path, recursive=True, maxdepth=1)
+
+    def _release(self) -> None:
+        super()._release()
+        self._invalidate_caches()
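
A minimal usage sketch, assuming the class is importable as below and that a
Parquet-capable Kedro dataset is installed; the module path, directory, and
data are illustrative, not taken from this commit:

    import pandas as pd

    from pasteur.kedro.dataset import Multiset  # hypothetical module path

    ds = Multiset(
        path="data/01_raw/parts",
        dataset="pandas.ParquetDataSet",
        filename_suffix=".pq",
    )

    # save() writes one file per key under `path`; load() eagerly returns a
    # dict of {partition_id: loaded data}
    ds.save({"2023/part": pd.DataFrame({"x": [1, 2]})})
    parts = ds.load()
    assert list(parts) == ["2023/part"]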