-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathutils.py
More file actions
336 lines (259 loc) · 12 KB
/
utils.py
File metadata and controls
336 lines (259 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
from __future__ import annotations
import asyncio
import contextlib
import functools
import io
import sys
from collections.abc import Awaitable, Callable, Iterator, Sequence
from pathlib import Path
from typing import Any, Literal, ParamSpec, TextIO, TypeVar, overload
import numpy.typing as npt
import pandas as pd
import scanpy as sc
from loguru import logger
from como.data_types import LOG_FORMAT, Algorithm, LogLevel
from como.pipelines.identifier import get_remaining_identifiers
# Typing helpers used by the `asyncable` decorator below.
P = ParamSpec("P")  # captures the wrapped function's parameter list
T = TypeVar("T")  # the wrapped function's return type

# Public API of this module.
# NOTE(review): `return_placeholder_data` and `_listify` are defined below but
# not exported here — confirm that is intentional.
__all__ = [
    "asyncable",
    "get_missing_gene_data",
    "num_columns",
    "num_rows",
    "read_file",
    "set_up_logging",
    "split_gene_expression_data",
    "stringlist_to_list",
    "suppress_stdout",
]
def stringlist_to_list(stringlist: str | list[str]) -> list[str]:
    """Parse a command-line "string list" into a real Python list.

    An actual Python list is returned untouched and a plain space-separated
    string is split on spaces. The legacy bracketed form is still parsed, but
    its use is deprecated and triggers loud log messages.

    Args:
        stringlist: The "string list" gathered from the command line. Example input: "['mat', 'xml', 'json']"

    Returns:
        A list of strings. Example output: ['mat', 'xml', 'json']
    """
    if isinstance(stringlist, list):
        return stringlist

    # Brackets wrapping the whole string mark the legacy input format
    uses_legacy_format = stringlist.startswith("[") and stringlist.endswith("]")
    if not uses_legacy_format:
        return stringlist.split(" ")

    # Legacy form: drop the brackets, quotes, and spaces, then split on commas
    items = stringlist.strip("[]").replace("'", "").replace(" ", "").split(",")

    # Tell the user this input style is going away
    logger.critical(
        "DeprecationWarning: Please use the new method of providing context names, "
        "i.e. --output-filetypes 'type1 type2 type3'."
    )
    logger.critical(
        "If you are using COMO, this can be done by setting the 'context_names' variable to a "
        "simple string separated by spaces. Here are a few examples!"
    )
    logger.critical("context_names = 'cellType1 cellType2 cellType3'")
    logger.critical("output_filetypes = 'output1 output2 output3'")
    logger.critical(
        "\nYour current method of passing context names will be removed in the future. "
        "Update your variables above accordingly!\n\n"
    )
    return items
def split_gene_expression_data(
    expression_data: pd.DataFrame,
    identifier_column: Literal["ensembl_gene_id", "entrez_gene_id"],
    recon_algorithm: Algorithm | None = None,
    *,
    ensembl_as_index: bool = True,
) -> pd.DataFrame:
    """Split the gene expression data into single-gene and multiple-gene rows.

    Rows whose identifier joins several genes with "///" are exploded into one
    row per gene; single-gene rows pass through unchanged.

    Args:
        expression_data: The gene expression data to map; must contain
            `identifier_column` and either an 'active' or 'combine_z' column
        identifier_column: The column containing the gene identifiers, either
            'ensembl_gene_id' or 'entrez_gene_id'
        recon_algorithm: The recon algorithm used to generate the gene expression data.
            NOTE(review): currently unused in this function; kept for interface compatibility
        ensembl_as_index: Should the identifier column be set as the index

    Returns:
        A pandas DataFrame with the split gene expression data
    """
    # Work on a copy so the caller's DataFrame is not mutated
    expression_data = expression_data.copy()
    expression_data.columns = [c.lower() for c in expression_data.columns]
    if "combine_z" in expression_data.columns:
        expression_data = expression_data.rename(columns={"combine_z": "active"})
    expression_data = expression_data[[identifier_column, "active"]]

    # Detect multi-gene identifiers with the same "///" delimiter used for
    # splitting (previously detection used "//" while the split used "///")
    has_multiple = expression_data[identifier_column].astype(str).str.contains("///")
    single_gene_names = expression_data[~has_multiple]
    # Assign the split lists back to the *identifier* column (previously
    # hard-coded to 'ensembl_gene_id', which left entrez identifiers unsplit),
    # then explode into one row per gene
    split_gene_names = (
        expression_data[has_multiple]
        .assign(
            **{
                identifier_column: expression_data.loc[has_multiple, identifier_column]
                .astype(str)
                .str.split("///")
            }
        )
        .explode(identifier_column)
    )

    gene_expressions = pd.concat([single_gene_names, split_gene_names], axis=0, ignore_index=True)
    if ensembl_as_index:
        gene_expressions = gene_expressions.set_index(identifier_column)
    return gene_expressions
@contextlib.contextmanager
def suppress_stdout() -> Iterator[None]:
    """Suppress stdout output from the current context.

    Restores whatever stream was installed as ``sys.stdout`` on entry rather
    than ``sys.__stdout__``, so nested or pre-existing redirections survive.

    :return: The context manager
    """
    previous_stdout = sys.stdout
    with io.StringIO() as buffer:
        try:
            sys.stdout = buffer
            yield
        finally:
            # Restoring sys.__stdout__ here (as before) would clobber any
            # outer redirection; restore the stream active on entry instead
            sys.stdout = previous_stdout
def get_missing_gene_data(values: Sequence[str] | pd.DataFrame | sc.AnnData, taxon_id: int | str) -> pd.DataFrame:
    """Get missing gene data from a given set of values.

    This function will attempt to find gene identifiers in the provided values and convert them to a DataFrame
    containing "entrez_gene_id", "ensembl_gene_id", and "gene_symbol"

    :param values: The values to extract gene identifiers from.
        This can be a list of strings, a pandas DataFrame, or a scanpy AnnData object.
    :param taxon_id: The taxonomic identifier to use for gene conversion
    :return: A DataFrame containing "entrez_gene_id", "ensembl_gene_id", and "gene_symbol" for the provided values
    :raises ValueError: If the DataFrame has duplicate column names, contains no
        recognized identifier column, or `values` is of an unsupported type
    """
    if isinstance(values, list):
        return get_remaining_identifiers(ids=values, taxon=taxon_id)
    elif isinstance(values, pd.DataFrame):
        # Duplicate column names would make column selection ambiguous — fail fast
        duplicated_mask = values.columns.duplicated(keep=False)
        if duplicated_mask.any():
            duplicate_cols = values.columns[duplicated_mask].unique().tolist()
            raise ValueError(
                f"Duplicate column names exist! This will result in an error processing data. "
                f"Duplicates: {','.join(duplicate_cols)}"
            )
        # Consider both the columns and the (named) index as identifier sources
        names: list[str] = values.columns.tolist()
        if values.index.name is not None:
            names.append(str(values.index.name))
        # Preference order for identifier columns; the first one present wins.
        # Recurse with the plain list of identifiers so the list branch does the work.
        for identifier in ("gene_symbol", "entrez_gene_id", "ensembl_gene_id"):
            if identifier in names:
                ids = values[identifier].tolist() if identifier in values.columns else values.index.tolist()
                return get_missing_gene_data(ids, taxon_id=taxon_id)
        raise ValueError(
            "Unable to find 'gene_symbol', 'entrez_gene_id', or 'ensembl_gene_id' in the input matrix."
        )
    else:
        # NOTE(review): sc.AnnData appears in the annotation but is not handled
        # by any branch, so AnnData input raises here — confirm this is intended
        raise ValueError(f"Values must be a list of strings or a pandas DataFrame, got: {type(values)}")
@overload
def read_file(path: None, h5ad_as_df: bool = True, **kwargs: Any) -> None: ...
@overload
def read_file(path: pd.DataFrame, h5ad_as_df: bool = True, **kwargs: Any) -> pd.DataFrame: ...
@overload
def read_file(path: io.StringIO, h5ad_as_df: bool = True, **kwargs: Any) -> pd.DataFrame: ...
@overload
def read_file(path: sc.AnnData, h5ad_as_df: Literal[False], **kwargs: Any) -> sc.AnnData: ...
@overload
def read_file(path: sc.AnnData, h5ad_as_df: Literal[True] = True, **kwargs: Any) -> pd.DataFrame: ...
@overload
def read_file(path: Path, h5ad_as_df: Literal[False], **kwargs: Any) -> pd.DataFrame | sc.AnnData: ...
@overload
def read_file(path: Path, h5ad_as_df: Literal[True] = True, **kwargs: Any) -> pd.DataFrame: ...
def read_file( # noqa: C901
path: Path | io.StringIO | pd.DataFrame | sc.AnnData | None,
h5ad_as_df: bool = True,
**kwargs: Any,
) -> pd.DataFrame | sc.AnnData | None:
"""Read a filepath and return pandas.DataFrame or scanpy.AnnData.
If the provided path is None, None will also be returned.
None may be provided to this function so that `asyncio.gather` can safely be used on all sources
(trna, mrna, scrna, proteomics) without needing to check if the user has provided those sources
Args:
path: The path to read from
h5ad_as_df: If True and the file is an h5ad, return a pandas DataFrame of the .X matrix
kwargs: Additional arguments to pass to pandas.read_csv, pandas.read_excel, or scanpy.read_h5ad
Returns:
None, or a pandas DataFrame or AnnData
"""
if isinstance(path, pd.DataFrame):
return path
elif isinstance(path, sc.AnnData):
return path.to_df().T if h5ad_as_df else path
elif isinstance(path, io.StringIO):
return pd.read_csv(path, **kwargs)
elif not path:
return None
if isinstance(path, Path) and not path.exists():
raise FileNotFoundError(f"File not found: '{path}'")
match path.suffix:
case ".csv" | ".tsv" | ".txt" | ".tab" | ".sf":
kwargs.setdefault("sep", "," if path.suffix == ".csv" else "\t") # set sep if not defined
return pd.read_csv(path, **kwargs)
case ".xlsx" | ".xls":
return pd.read_excel(path, **kwargs)
case ".h5ad":
adata: sc.AnnData = sc.read_h5ad(path, **kwargs)
if h5ad_as_df:
df = adata.to_df().T
if not df.index.name:
df.index.name = "gene_symbol"
df.reset_index(inplace=True)
return df
return adata
case _:
raise ValueError(
f"Unknown file extension '{path.suffix}'. Valid options are '.tsv', '.csv', '.xlsx', '.xls', or '.h5ad'"
)
@overload
def _listify(value: list[T]) -> list[T]: ...
@overload
def _listify(value: T) -> list[T]: ...
def _listify(value: T | list[T]) -> list[T]:
"""Convert items into a list.
Args:
value: The value to convert to a list (if it isn't already)
Returns:
A list of the provided value
"""
return value if isinstance(value, list) else [value]
def num_rows(item: pd.DataFrame | npt.NDArray) -> int:
    """Return the number of rows in an object.

    :param item: The object to check the number of rows of
    :return: The number of rows in the provided object
    """
    # The first entry of `.shape` is the row count for both DataFrames and arrays
    rows = item.shape[0]
    return rows
def num_columns(item: pd.DataFrame | npt.NDArray) -> int:
    """Return the number of columns in an object.

    :param item: The object to check the number of columns of
    :return: The number of columns in the provided object
    """
    # The second entry of `.shape` is the column count for both DataFrames and arrays
    columns = item.shape[1]
    return columns
def return_placeholder_data() -> pd.DataFrame:
    """Return an all-zero placeholder expression DataFrame.

    The frame has a single 'entrez_gene_id' index entry of 0 and the columns
    'expressed' and 'top', both filled with 0.
    """
    placeholder_index = pd.Index(data=[0], name="entrez_gene_id")
    return pd.DataFrame(data=0, index=placeholder_index, columns=["expressed", "top"])
def set_up_logging(
    level: LogLevel | str,
    location: str | TextIO,
    formatting: str = LOG_FORMAT,
) -> None:
    """Set up logging for the application.

    :param level: The default logging level to use (e.g., LogLevel.INFO, LogLevel.DEBUG, etc.);
        a plain string (e.g., "info") is mapped onto the LogLevel enum by name
    :param level: may be given as a string or a LogLevel member
    :param location: The location to log to (e.g., a file path or sys.stdout)
    :param formatting: The log message format to use (default is LOG_FORMAT)
    """
    # Accept plain strings by looking them up (case-insensitively) in the LogLevel enum
    if isinstance(level, str):
        level = LogLevel[level.upper()]
    # Remove loguru's default handler (id 0); suppress the ValueError raised if
    # it was already removed by an earlier call to this function
    with contextlib.suppress(ValueError):
        logger.remove(0)
    logger.add(sink=location, level=level.value, format=formatting)
def asyncable(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
    """Convert a synchronous function to an asynchronous one.

    The returned coroutine function runs the synchronous function in the event
    loop's default executor (a thread pool) via `run_in_executor`, so calling
    it does not block the event loop.

    :param func: The synchronous function to convert.
    :return: An asynchronous version of the input function that runs in a separate thread.
    """

    @functools.wraps(func)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        loop = asyncio.get_running_loop()
        # Bind the arguments with functools.partial because run_in_executor
        # only accepts positional arguments for the callable
        call = functools.partial(func, *args, **kwargs)
        return await loop.run_in_executor(None, call)

    return wrapper