Skip to content

Commit 0f5cd79

Browse files
authored
fix: mypy errors (#806)
1 parent e400e0e commit 0f5cd79

2 files changed

Lines changed: 13 additions & 8 deletions

File tree

src/litdata/streaming/fs_provider.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,9 @@
2020

2121

2222
class FsProvider(ABC):
23-
def __init__(self, storage_options: dict[str, Any] | None = {}):
23+
def __init__(self, storage_options: dict[str, Any] | None = None):
24+
if storage_options is None:
25+
storage_options = {}
2426
self.storage_options = storage_options
2527

2628
@abstractmethod
@@ -50,7 +52,7 @@ def is_empty(self, path: str) -> bool:
5052

5153

5254
class GCPFsProvider(FsProvider):
53-
def __init__(self, storage_options: dict[str, Any] | None = {}):
55+
def __init__(self, storage_options: dict[str, Any] | None = None):
5456
if not _GOOGLE_STORAGE_AVAILABLE:
5557
raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE))
5658
from google.cloud import storage
@@ -133,9 +135,9 @@ def is_empty(self, path: str) -> bool:
133135

134136

135137
class S3FsProvider(FsProvider):
136-
def __init__(self, storage_options: dict[str, Any] | None = {}):
138+
def __init__(self, storage_options: dict[str, Any] | None = None):
137139
super().__init__(storage_options=storage_options)
138-
self.client = S3Client(storage_options=storage_options)
140+
self.client = S3Client(storage_options=self.storage_options)
139141

140142
def upload_file(self, local_path: str, remote_path: str) -> None:
141143
bucket_name, blob_path = get_bucket_and_path(remote_path, "s3")
@@ -225,11 +227,11 @@ def is_empty(self, path: str) -> bool:
225227

226228

227229
class R2FsProvider(S3FsProvider):
228-
def __init__(self, storage_options: dict[str, Any] | None = {}):
230+
def __init__(self, storage_options: dict[str, Any] | None = None):
229231
super().__init__(storage_options=storage_options)
230232

231233
# Create R2Client with refreshable credentials
232-
self.client = R2Client(storage_options=storage_options)
234+
self.client = R2Client(storage_options=self.storage_options)
233235

234236
def upload_file(self, local_path: str, remote_path: str) -> None:
235237
bucket_name, blob_path = get_bucket_and_path(remote_path, "r2")
@@ -324,7 +326,7 @@ def get_bucket_and_path(remote_filepath: str, expected_scheme: str = "s3") -> tu
324326
return bucket_name, blob_path
325327

326328

327-
def _get_fs_provider(remote_filepath: str, storage_options: dict[str, Any] | None = {}) -> FsProvider:
329+
def _get_fs_provider(remote_filepath: str, storage_options: dict[str, Any] | None = None) -> FsProvider:
328330
obj = parse.urlparse(remote_filepath)
329331
if obj.scheme == "gs":
330332
return GCPFsProvider(storage_options=storage_options)

src/litdata/utilities/parquet.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,12 @@ def __init__(
2121
self,
2222
dir_path: str | Dir | None,
2323
cache_path: str | None = None,
24-
storage_options: dict | None = {},
24+
storage_options: dict | None = None,
2525
num_workers: int = 4,
2626
):
27+
if storage_options is None:
28+
storage_options = {}
29+
2730
self.dir = _resolve_dir(dir_path)
2831
self.cache_path = cache_path
2932
self.storage_options = storage_options

0 commit comments

Comments
 (0)