util.py
"""Util functions."""
import contextlib
import logging
import multiprocessing
import queue
import sys
import threading
import traceback
from itertools import repeat
from multiprocessing import Pool
import numpy as np
import pandas as pd
from tqdm import tqdm
from tqdm.contrib import DummyTqdmFile
L = logging.getLogger(__name__)


class StreamToQueue(DummyTqdmFile):
    """Fake file-like stream object that redirects all prints to a Queue."""

    def __init__(self, *args, message_queue=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.message_queue = message_queue

    def write(self, buf):
        """Redirect write calls to the Queue."""
        if len(buf.rstrip()) > 0:
            self.message_queue.put(buf)
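
# Illustrative sketch (assumption, not part of the original module): wrapping a
# stream in ``StreamToQueue`` routes printed text into a queue, so another
# thread can interleave it with a tqdm bar instead of corrupting its display:
#
#     q = queue.Queue()
#     stream = StreamToQueue(sys.stdout, message_queue=q)
#     print("hello", file=stream)
#     q.get()  # -> "hello" (the blank trailing newline write is dropped)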


@contextlib.contextmanager
def _std_out_err_redirect_func(*args, **kwargs):
    """Context manager used to redirect stdout and stderr to tqdm.write()."""
    orig_out_err = sys.stdout, sys.stderr
    try:
        if _std_out_err_redirect_func._redirect:  # pylint: disable=protected-access
            sys.stdout = StreamToQueue(sys.stdout, *args, **kwargs)
            sys.stderr = StreamToQueue(sys.stderr, *args, **kwargs)
        yield
    # Relay exceptions
    except Exception as exc:  # pragma: no cover
        raise exc
    # Always restore sys.stdout/err if necessary
    finally:
        sys.stdout, sys.stderr = orig_out_err


def _tqdm_wrapper(*args, **kwargs):
    # pylint: disable=protected-access
    _tqdm_wrapper._tqdm_queue.put(1)
    with _std_out_err_redirect_func(message_queue=_tqdm_wrapper._message_queue):
        res = try_operation(*args, **kwargs)
    return res


def _init_tqdm_global_queue(tqdm_queue, message_queue, redirect=True):
    _tqdm_wrapper._tqdm_queue = tqdm_queue  # pylint: disable=protected-access
    _tqdm_wrapper._message_queue = message_queue  # pylint: disable=protected-access
    _std_out_err_redirect_func._redirect = redirect  # pylint: disable=protected-access


def _reset_tqdm_global_queue():
    _tqdm_wrapper._tqdm_queue = None  # pylint: disable=protected-access
    _tqdm_wrapper._message_queue = None  # pylint: disable=protected-access
    _std_out_err_redirect_func._redirect = True  # pylint: disable=protected-access


def _apply_to_df_internal(data):
    (num, df), args, kwargs = data
    return num, df.apply(_tqdm_wrapper, axis=1, args=args, **kwargs)


def _restore_object_nulls(result_df, template_df):
    """Restore None values for columns that were object-typed in the input."""
    for col, dtype in template_df.dtypes.items():
        if col in result_df.columns and pd.api.types.is_object_dtype(dtype):
            result_df[col] = result_df[col].astype(object)
            result_df.loc[result_df[col].isna(), col] = None
    return result_df
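
# Illustrative sketch (assumption, not part of the original module): pandas
# operations such as ``DataFrame.apply`` may coerce ``None`` in object-typed
# columns to ``NaN``; ``_restore_object_nulls`` undoes that coercion, e.g.:
#
#     template = pd.DataFrame({"a": pd.Series(["x", None], dtype=object)})
#     result = template.copy()
#     result.loc[1, "a"] = float("nan")  # simulate the coercion
#     result = _restore_object_nulls(result, template)
#     assert result.loc[1, "a"] is None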


def tqdm_worker(progress_bar, tqdm_queue):
    """Update the progress bar using the Queue; a ``None`` item stops the worker."""
    while True:
        db = tqdm_queue.get()
        if db is None:
            return
        progress_bar.update(db)


def message_worker(progress_bar, message_queue):
    """Write a message without interfering with the progress bar using the message Queue."""
    while True:
        message = message_queue.get()
        if message is None:
            return
        progress_bar.write(message)


def apply_to_df(df, func, *args, nb_processes=None, redirect_stdout=None, **kwargs):
    """Apply a function to the rows of ``df`` with a tqdm progress bar, optionally in parallel."""
    if df.empty:
        return df
    nb_jobs = len(df)
    if redirect_stdout is None:
        redirect_stdout = True
    if nb_processes is None or nb_processes <= 1:
        # Serial computation
        is_parallel = False
        tqdm_queue = queue.Queue()
        message_queue = queue.Queue()
    else:
        # Parallel computation
        is_parallel = True
        nb_chunks = min(nb_jobs, nb_processes)
        chunks = enumerate(df.loc[i] for i in np.array_split(df.index, nb_chunks))
        tqdm_queue = multiprocessing.Queue()
        message_queue = multiprocessing.Queue()

    # Create the progress bar and the threads that consume the two queues
    progress_bar = tqdm(total=nb_jobs, dynamic_ncols=True)
    tqdm_thread = threading.Thread(target=tqdm_worker, args=(progress_bar, tqdm_queue))
    message_thread = threading.Thread(target=message_worker, args=(progress_bar, message_queue))
    tqdm_thread.start()
    message_thread.start()

    if is_parallel:
        with Pool(
            nb_chunks,
            _init_tqdm_global_queue,
            [tqdm_queue, message_queue, redirect_stdout],
        ) as pool:
            # Start the computation
            results_list = pool.imap(
                _apply_to_df_internal,
                zip(chunks, repeat([func] + list(args)), repeat(kwargs)),  # pylint: disable=E0606
            )
            pool.close()
            # Wait for the last jobs to complete
            pool.join()
            # Gather all results in the original chunk order
            all_res = pd.concat([j for i, j in sorted(results_list, key=lambda x: x[0])])
    else:
        _init_tqdm_global_queue(tqdm_queue, message_queue, redirect_stdout)
        _, all_res = _apply_to_df_internal(((0, df), [func] + list(args), kwargs))
        _reset_tqdm_global_queue()

    # Terminate the worker threads with a sentinel value
    tqdm_queue.put_nowait(None)
    message_queue.put_nowait(None)
    tqdm_thread.join()
    message_thread.join()
    # Close the progress bar
    progress_bar.close()
    # Return the results
    return _restore_object_nulls(all_res, df)
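
# Usage sketch (hypothetical, made up for illustration): each row must carry
# the bookkeeping columns used by ``try_operation`` below (``is_valid``,
# ``exception`` and ``ret_code``), and ``func`` must return a mapping of
# column names to new values. The ``double_x`` helper is an assumption:
#
#     def double_x(row):
#         return {"x": row["x"] * 2}
#
#     df = pd.DataFrame({"x": [1, 2, 3]})
#     df["is_valid"] = True
#     df["exception"] = None
#     df["ret_code"] = 0
#     result = apply_to_df(df, double_x, nb_processes=2)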


def try_operation(row, func, *args, **kwargs):
    """Try to apply a function on a :class:`pandas.Series` and record any exception."""
    if not row.is_valid:
        return row
    try:
        res = func(row, *args, **kwargs)
        row.loc[res.keys()] = list(res.values())
    except Exception:  # pylint: disable=broad-except
        exception = "".join(traceback.format_exception(*sys.exc_info()))
        L.warning("Exception for ID %s: %s", row.name, exception)
        row.exception = exception
        row.is_valid = False
        row.ret_code = 1
    return row


def check_missing_columns(df, required_columns):
    """Return a sorted list of missing columns in a :class:`pandas.DataFrame`.

    Args:
        df (pandas.DataFrame): The DataFrame to check.
        required_columns (list): The list of column names. A column name can be a str for a
            single-level column or, for a two-level column, either a tuple of str or a dict
            mapping first-level names to lists of second-level names.
    """
    missing_cols = []
    for col in required_columns:  # pylint: disable=too-many-nested-blocks
        if isinstance(col, str) and col not in df.columns.get_level_values(0):
            missing_cols.append(col)
        elif isinstance(col, (list, tuple)) and tuple(col) not in df.columns:
            # Convert to tuple so the membership test also works for list inputs
            missing_cols.append(tuple(col))
        elif isinstance(col, dict):
            for level1 in col:
                if isinstance(col[level1], str):
                    col_level2 = [col[level1]]
                else:
                    col_level2 = col[level1]
                if level1 not in df.columns.get_level_values(0):
                    for level2 in col_level2:
                        missing_cols.append((level1, level2))
                else:
                    for level2 in col_level2:
                        if level2 not in df[level1].columns.get_level_values(0):
                            missing_cols.append((level1, level2))
    # Sort by str so str and tuple entries can be ordered together
    return sorted(set(missing_cols), key=str)


def report_missing_columns(df, required_columns):
    """Check that required columns exist in a :class:`pandas.DataFrame`.

    Args:
        df (pandas.DataFrame): The DataFrame to check.
        required_columns (list): The list of column names. A column name can be a str for a
            single-level column or, for a two-level column, either a tuple of str or a dict
            mapping first-level names to lists of second-level names.

    Returns:
        tuple: True when no column is missing, a report string, and the list of missing
        columns.
    """
    missing_cols = check_missing_columns(df, required_columns)
    report_str = f"Missing columns: {missing_cols}"
    if missing_cols:
        return False, report_str, missing_cols
    return True, report_str, missing_cols
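
# Usage sketch (hypothetical, made up for illustration): checking single-level
# and two-level column specifications against a MultiIndex DataFrame:
#
#     columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")])
#     df = pd.DataFrame([[1, 2, 3]], columns=columns)
#     ok, report, missing = report_missing_columns(
#         df, ["a", ("b", "z"), {"a": ["x", "missing"]}]
#     )
#     # ok is False and missing == [("a", "missing")]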