Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 71 additions & 4 deletions cloudpathlib/s3/s3client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import mimetypes
import os
from contextlib import contextmanager
from functools import partial
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

Expand All @@ -19,6 +21,55 @@
implementation_registry["s3"].dependencies_loaded = False


@contextmanager
def transfer_callable(direction: str, callable: Union[Callable, None], cloud_path: S3Path):
"""A context reporting progress when downloading/uploading files from the cloud.

If callable is `None` then `None` is yielded which disables the `Callable` feature.
Otherwise it will yield a new callable that can be passed to the `Callable`
argument of `S3.Object`'s `download_file` or `upload_file` actions. This callable
is called "periodically" by the `Callable` argument for `S3.Object` `download_file` and `upload_file`
[actions](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/object/index.html#actions).

The callable must accept (direction, state, cloud_path, bytes_sent) args.
`direction` is either "download" or "upload" string. `state` is a "start", "update"
or "stop" indicating if the download started, progressed or finished.
`cloud_path` is the `S3Path` being transferred. `bytes_sent` is how many bytes
were transferred by this progress report. The sum of `bytes_sent` should add up
to `cloud_path.stat().st_size` and can be used to track transfer progress.

Before it yields the callable it will call it with state set to "start". This
lets you update UI to indicate that `cloud_path` is starting to transfer.
While the transfer is processing the yielded callable should be called "periodically".
Once the transfer finishes it will call callable with the state "stop".
This lets you update your UI to indicate that the transfer finished.


Args:
direction (str): Passed to callable indicating the transfer direction.
Should only be "download" or "upload".
callable (Union[Callable, None]): A callable that is called when a transfer
start, stops or is updated.
cloud_path (S3Path): The s3 resource being transferred.
"""
if callable is None:
# No callable requested, disable the Callable feature without indicating
# start or stop.
yield None
return

# Indicate that the transfer has started
callable(direction, "start", cloud_path, 0)

try:
# Return the callable that is passed to `Callable` filling in the arguments
# we require but `S3.Object` don't provide.
yield partial(callable, direction, "update", cloud_path)
finally:
# Indicate that the transfer stopped
callable(direction, "stop", cloud_path, 0)


@register_client_class("s3")
class S3Client(Client):
"""Client class for AWS S3 which handles authentication with AWS for [`S3Path`](../s3path/)
Expand All @@ -40,6 +91,7 @@ def __init__(
boto3_transfer_config: Optional["TransferConfig"] = None,
content_type_method: Optional[Callable] = mimetypes.guess_type,
extra_args: Optional[dict] = None,
transfer_callable: Optional[Callable] = None,
):
"""Class constructor. Sets up a boto3 [`Session`](
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html).
Expand Down Expand Up @@ -78,6 +130,9 @@ def __init__(
can include any keys supported by upload or download, and we will pass on only the relevant args. To see the extra
args that are supported look at the upload and download lists in the
[boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer).
transfer_callable (Optional[Callable]): A callable function passed to download and upload functions.
This can be used to show, update and hide transfer progress UI.
See the `transfer_callable` context manager for details.
"""
endpoint_url = endpoint_url or os.getenv("AWS_ENDPOINT_URL")
if boto3_session is not None:
Expand Down Expand Up @@ -128,6 +183,8 @@ def __init__(
}
self._endpoint_url = endpoint_url

self.transfer_callable = transfer_callable

super().__init__(
local_cache_dir=local_cache_dir,
content_type_method=content_type_method,
Expand All @@ -152,9 +209,13 @@ def _download_file(self, cloud_path: S3Path, local_path: Union[str, os.PathLike]
local_path = Path(local_path)
obj = self.s3.Object(cloud_path.bucket, cloud_path.key)

obj.download_file(
str(local_path), Config=self.boto3_transfer_config, ExtraArgs=self.boto3_dl_extra_args
)
with transfer_callable("download", self.transfer_callable, cloud_path) as callback:
obj.download_file(
str(local_path),
Config=self.boto3_transfer_config,
ExtraArgs=self.boto3_dl_extra_args,
Callback=callback,
)
return local_path

def _is_file_or_dir(self, cloud_path: S3Path) -> Optional[str]:
Expand Down Expand Up @@ -355,7 +416,13 @@ def _upload_file(self, local_path: Union[str, os.PathLike], cloud_path: S3Path)
if content_encoding is not None:
extra_args["ContentEncoding"] = content_encoding

obj.upload_file(str(local_path), Config=self.boto3_transfer_config, ExtraArgs=extra_args)
with transfer_callable("download", self.transfer_callable, cloud_path) as callback:
obj.upload_file(
str(local_path),
Config=self.boto3_transfer_config,
ExtraArgs=extra_args,
Callback=callback,
)
return cloud_path

def _get_public_url(self, cloud_path: S3Path) -> str:
Expand Down