|
7 | 7 |
|
8 | 8 | import gzip |
9 | 9 | import inspect |
| 10 | +import itertools |
10 | 11 | import os |
11 | 12 | import platform |
12 | 13 | import sys |
13 | 14 | import warnings |
| 15 | +from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait |
14 | 16 | from copy import copy |
15 | 17 | from functools import cache |
16 | 18 | from importlib import import_module |
@@ -831,6 +833,83 @@ def _submit_requests( # noqa |
831 | 833 |
|
832 | 834 | return total_data |
833 | 835 |
|
| 836 | + # this is here as a separate function to allow for multithreading when querying s3 buckets |
| 837 | + # which is necessary to speed up retrieval of large data dumps |
| 838 | + def _multi_thread( |
| 839 | + self, |
| 840 | + func: Callable, |
| 841 | + params_list: list[dict], |
| 842 | + progress_bar: tqdm | None = None, |
| 843 | + ): |
| 844 | + """Handles setting up a threadpool and sending parallel requests. |
| 845 | +
|
| 846 | + Arguments: |
| 847 | + func (Callable): Callable function to multi |
| 848 | + params_list (list): list of dictionaries containing url and params for each request |
| 849 | + progress_bar (tqdm): progress bar to update with progress |
| 850 | +
|
| 851 | + Returns: |
| 852 | + Tuples with data, total number of docs in matching the query in the database, |
| 853 | + and the index of the criteria dictionary in the provided parameter list |
| 854 | + """ |
| 855 | + return_data = [] |
| 856 | + |
| 857 | + params_gen = iter( |
| 858 | + params_list |
| 859 | + ) # Iter necessary for islice to keep track of what has been accessed |
| 860 | + |
| 861 | + params_ind = 0 |
| 862 | + |
| 863 | + with ThreadPoolExecutor( |
| 864 | + max_workers=MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS # type: ignore |
| 865 | + ) as executor: |
| 866 | + # Get list of initial futures defined by max number of parallel requests |
| 867 | + futures = set() |
| 868 | + |
| 869 | + for params in itertools.islice( |
| 870 | + params_gen, |
| 871 | + MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS, # type: ignore |
| 872 | + ): |
| 873 | + future = executor.submit( |
| 874 | + func, |
| 875 | + **params, |
| 876 | + ) |
| 877 | + |
| 878 | + future.crit_ind = params_ind # type: ignore |
| 879 | + futures.add(future) |
| 880 | + params_ind += 1 |
| 881 | + |
| 882 | + while futures: |
| 883 | + # Wait for at least one future to complete and process finished |
| 884 | + finished, futures = wait(futures, return_when=FIRST_COMPLETED) |
| 885 | + |
| 886 | + for future in finished: |
| 887 | + data, subtotal = future.result() |
| 888 | + |
| 889 | + if progress_bar is not None: |
| 890 | + if isinstance(data, dict): |
| 891 | + size = len(data["data"]) |
| 892 | + elif isinstance(data, list): |
| 893 | + size = len(data) |
| 894 | + else: |
| 895 | + size = 1 |
| 896 | + progress_bar.update(size) |
| 897 | + |
| 898 | + return_data.append((data, subtotal, future.crit_ind)) # type: ignore |
| 899 | + |
| 900 | + # Populate more futures to replace finished |
| 901 | + for params in itertools.islice(params_gen, len(finished)): |
| 902 | + new_future = executor.submit( |
| 903 | + func, |
| 904 | + **params, |
| 905 | + ) |
| 906 | + |
| 907 | + new_future.crit_ind = params_ind # type: ignore |
| 908 | + futures.add(new_future) |
| 909 | + params_ind += 1 |
| 910 | + |
| 911 | + return return_data |
| 912 | + |
834 | 913 | def _submit_request_and_process( |
835 | 914 | self, |
836 | 915 | url: str, |
|
0 commit comments