diff --git a/cloudpathlib/s3/s3client.py b/cloudpathlib/s3/s3client.py index db130e82..c10aaa71 100644 --- a/cloudpathlib/s3/s3client.py +++ b/cloudpathlib/s3/s3client.py @@ -1,5 +1,7 @@ import mimetypes import os +from contextlib import contextmanager +from functools import partial from pathlib import Path, PurePosixPath from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union @@ -19,6 +21,55 @@ implementation_registry["s3"].dependencies_loaded = False +@contextmanager +def transfer_callable(direction: str, callable: Union[Callable, None], cloud_path: S3Path): + """A context reporting progress when downloading/uploading files from the cloud. + + If callable is `None` then `None` is yielded which disables the `Callable` feature. + Otherwise it will yield a new callable that can be passed to the `Callable` + argument of `S3.Object`'s `download_file` or `upload_file` actions. This callable + is called "periodically" by the `Callable` argument for `S3.Object` `download_file` and `upload_file` + [actions](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/object/index.html#actions). + + The callable must accept (direction, state, cloud_path, bytes_sent) args. + `direction` is either "download" or "upload" string. `state` is a "start", "update" + or "stop" indicating if the download started, progressed or finished. + `cloud_path` is the `S3Path` being transferred. `bytes_sent` is how many bytes + were transferred by this progress report. The sum of `bytes_sent` should add up + to `cloud_path.stat().st_size` and can be used to track transfer progress. + + Before it yields the callable it will call it with state set to "start". This + lets you update UI to indicate that `cloud_path` is starting to transfer. + While the transfer is processing the yielded callable should be called "periodically". + Once the transfer finishes it will call callable with the state "stop". + This lets you update your UI to indicate that the transfer finished. + + + Args: + direction (str): Passed to callable indicating the transfer direction. + Should only be "download" or "upload". + callable (Union[Callable, None]): A callable that is called when a transfer + start, stops or is updated. + cloud_path (S3Path): The s3 resource being transferred. + """ + if callable is None: + # No callable requested, disable the Callable feature without indicating + # start or stop. + yield None + return + + # Indicate that the transfer has started + callable(direction, "start", cloud_path, 0) + + try: + # Return the callable that is passed to `Callable` filling in the arguments + # we require but `S3.Object` don't provide. + yield partial(callable, direction, "update", cloud_path) + finally: + # Indicate that the transfer stopped + callable(direction, "stop", cloud_path, 0) + + @register_client_class("s3") class S3Client(Client): """Client class for AWS S3 which handles authentication with AWS for [`S3Path`](../s3path/) @@ -40,6 +91,7 @@ def __init__( boto3_transfer_config: Optional["TransferConfig"] = None, content_type_method: Optional[Callable] = mimetypes.guess_type, extra_args: Optional[dict] = None, + transfer_callable: Optional[Callable] = None, ): """Class constructor. Sets up a boto3 [`Session`]( https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html). @@ -78,6 +130,9 @@ def __init__( can include any keys supported by upload or download, and we will pass on only the relevant args. To see the extra args that are supported look at the upload and download lists in the [boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer). + transfer_callable (Optional[Callable]): A callable function passed to download and upload functions. + This can be used to show, update and hide transfer progress UI. + See the `transfer_callable` context manager for details. """ endpoint_url = endpoint_url or os.getenv("AWS_ENDPOINT_URL") if boto3_session is not None: @@ -128,6 +183,8 @@ def __init__( } self._endpoint_url = endpoint_url + self.transfer_callable = transfer_callable + super().__init__( local_cache_dir=local_cache_dir, content_type_method=content_type_method, @@ -152,9 +209,13 @@ def _download_file(self, cloud_path: S3Path, local_path: Union[str, os.PathLike] local_path = Path(local_path) obj = self.s3.Object(cloud_path.bucket, cloud_path.key) - obj.download_file( - str(local_path), Config=self.boto3_transfer_config, ExtraArgs=self.boto3_dl_extra_args - ) + with transfer_callable("download", self.transfer_callable, cloud_path) as callback: + obj.download_file( + str(local_path), + Config=self.boto3_transfer_config, + ExtraArgs=self.boto3_dl_extra_args, + Callback=callback, + ) return local_path def _is_file_or_dir(self, cloud_path: S3Path) -> Optional[str]: @@ -355,7 +416,13 @@ def _upload_file(self, local_path: Union[str, os.PathLike], cloud_path: S3Path) if content_encoding is not None: extra_args["ContentEncoding"] = content_encoding - obj.upload_file(str(local_path), Config=self.boto3_transfer_config, ExtraArgs=extra_args) + with transfer_callable("download", self.transfer_callable, cloud_path) as callback: + obj.upload_file( + str(local_path), + Config=self.boto3_transfer_config, + ExtraArgs=extra_args, + Callback=callback, + ) return cloud_path def _get_public_url(self, cloud_path: S3Path) -> str: