This repository was archived by the owner on Nov 15, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgeneral.py
More file actions
173 lines (139 loc) · 6.7 KB
/
general.py
File metadata and controls
173 lines (139 loc) · 6.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Generic Imports
import re
from datetime import datetime
from itertools import product as cartesian_product
from functools import reduce, wraps
from operator import mul
from copy import deepcopy
from pathlib import Path
# Typing and Subclassing
from typing import Any, Callable, Iterable, Optional, Union
from dataclasses import dataclass
# Units
from pint import Quantity as PintQuantity
from openmm.unit.quantity import Quantity as OMMQuantity
# Math
greek_letter_names = [ # names for greek character literals, in Unicode code point order
    'alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta', 'eta', 'theta',
    'iota', 'kappa', 'lambda', 'mu', 'nu', 'xi', 'omicron', 'pi',
    'rho', 'sigma_end', 'sigma', 'tau', 'upsilon', 'phi', 'chi', 'psi', 'omega'
]
_greek_start_idxs = { # code points where each case of the Greek alphabet starts in Unicode
    'LOWER' : 945, # U+03B1, lowercase alpha
    'UPPER' : 913, # U+0391, uppercase alpha
}

def _build_greek_alphabet(start_idx : int) -> dict[str, str]:
    '''Map each name in greek_letter_names to the character found by counting up consecutively from code point start_idx'''
    return {
        letter_name : chr(start_idx + i)
        for i, letter_name in enumerate(greek_letter_names)
    }

# Explicit assignments (rather than injection via globals()) so these names are visible to linters / IDEs
GREEK_LOWER : dict[str, str] = _build_greek_alphabet(_greek_start_idxs['LOWER'])
GREEK_UPPER : dict[str, str] = _build_greek_alphabet(_greek_start_idxs['UPPER'])
# Final sigma has no distinct capital form, so the uppercase slot counted out for 'sigma_end' (U+03A2)
# is an unassigned code point; the capital of final sigma is just the ordinary capital sigma
GREEK_UPPER['sigma_end'] = GREEK_UPPER['sigma']
def product(container : Iterable, start : Any=1) -> Any:
    '''Multiply together all items of an iterable; the multiplicative analogue of builtin sum()

    Like sum() (and math.prod()), a `start` value seeds the running product, so an empty
    iterable returns `start` rather than raising TypeError

    Parameters
    ----------
    container : Iterable
        Items to multiply together, combined left-to-right with operator.mul
    start : Any, default 1
        Initial value of the running product (the multiplicative identity by default)

    Returns
    -------
    start * item_0 * item_1 * ... * item_n
    '''
    return reduce(mul, container, start)
@dataclass
class Accumulator:
    '''Running total and tally of values, from which their mean can be read off at any time'''
    sum : float = 0.0 # running total of all accumulated values
    count : int = 0 # number of values accumulated so far

    @property
    def average(self) -> float:
        '''Mean of the accumulated values (sum / count); raises ZeroDivisionError while count is 0'''
        total, n_values = self.sum, self.count
        return total / n_values
# Functional modifiers and decorators
def generate_repr(cls : Any) -> Any:
    '''Class decorator for autogenerating __repr__ for attributes specified in a DISP_ATTRS class attribute

    The class this is applied to MUST define an iterable DISP_ATTRS class attribute of attribute names.
    The generated repr has the form "ClassName(attr1=value1, attr2=value2, ...)", with values rendered via str()

    Raises
    ------
    TypeError
        If the decorated class does not define DISP_ATTRS
    '''
    if not hasattr(cls, 'DISP_ATTRS'): # explicit check rather than assert, since asserts are stripped under python -O
        raise TypeError(f'Cannot generate __repr__ for {cls.__name__}, which does not define a DISP_ATTRS class attribute')

    def _repr_generic_(self) -> str:
        # Resolve the class and its DISP_ATTRS at call time, so subclasses inheriting this __repr__
        # report their own name and any overridden DISP_ATTRS rather than those of the decorated parent
        obj_type = type(self)
        attr_str = ', '.join(f'{attr}={getattr(self, attr)}' for attr in obj_type.DISP_ATTRS)
        return f'{obj_type.__name__}({attr_str})'

    cls.__repr__ : Callable[[Any], str] = _repr_generic_
    return cls
def optional_in_place(funct : Callable[[Any, Any], None]) -> Callable[[Any, Any], Optional[Any]]:
    '''Decorator function for allowing in-place (writeable) functions which modify object attributes
    to be not performed in-place (i.e. read-only), specified by a boolean flag

    The decorated function gains a keyword-only `in_place` flag (default False):
    - in_place=False : funct is applied to a deep copy of the first argument, and that modified copy is returned
    - in_place=True  : funct is applied to the first argument itself, and None is returned
    '''
    @wraps(funct) # preserve the wrapped function's __name__, __doc__, etc. for introspection and help()
    def in_place_wrapper(obj : Any, *args, in_place : bool=False, **kwargs) -> Optional[Any]: # read-only by default
        if in_place:
            funct(obj, *args, **kwargs) # call the writeable function directly on the original object
            return None

        copy_obj = deepcopy(obj) # clone object to avoid modifying the original
        funct(copy_obj, *args, **kwargs)
        return copy_obj # return the new, modified object
    return in_place_wrapper
def asiterable(arg_val : Union[Any, Iterable[Any]]) -> Iterable[Any]:
    '''Let functions expecting an iterable also accept a lone value

    Non-iterable values are wrapped in a single-item tuple (lighter than a list);
    anything already iterable (strings included) is handed back as the very same object
    '''
    return arg_val if isinstance(arg_val, Iterable) else (arg_val,)
def aspath(path : Union[Path, str]) -> Path:
    '''Coerce a string path into a Path, so Path-expecting functions can also accept strings; Path inputs pass through as-is'''
    if isinstance(path, Path):
        return path
    return Path(path)
def asstrpath(strpath : Union[str, Path]) -> str:
    '''Coerce a Path (or any non-str) to its str form, so string-path-expecting functions can also accept Paths; str inputs pass through as-is'''
    return strpath if isinstance(strpath, str) else str(strpath)
# Tools for iteration
def swappable_loop_order(iter1 : Iterable, iter2 : Iterable, swap : bool=False) -> Iterable[tuple[Any, Any]]:
    '''Enables dynamic swapping of the order of execution of a 2-nested for loop

    Yields pairs from the Cartesian product of iter1 and iter2. When swap is falsy, iter1 drives the
    outer loop; when swap is truthy, iter2 does. Either way, every pair is yielded as
    (item from iter1, item from iter2), so the pair order always matches the argument order.
    '''
    if swap:
        for item2, item1 in cartesian_product(iter2, iter1): # iter2 is now the outer (slower-varying) loop
            yield (item1, item2) # restore argument order within each pair
    else:
        yield from cartesian_product(iter1, iter2)
def progress_iter(itera : Iterable, key : Callable[[Any], str]=lambda x : x) -> Iterable[tuple[str, Any]]:
    '''Iterate through an iterable, pairing each item with a progress label

    Yields (label, item) tuples, where label reads "<key(item)> (<step> / <total>)" with a 1-indexed step.
    Iterables with no len() (e.g. generators / consumables) are first drained into a list so that the
    total is known up front; note that this holds all of their items in memory at once.
    '''
    try:
        n_items = len(itera)
    except TypeError: # no __len__, so materialize the items in order to count them
        itera = list(itera)
        n_items = len(itera)

    for step, item in enumerate(itera, start=1): # start at 1 for a more human-readable step count
        yield (f'{key(item)} ({step} / {n_items})', item)
# Data containers / data structures
@optional_in_place
def modify_dict(path_dict : dict[Any, Any], modifier_fn : Callable[[Any, Any], Any]) -> None:
    '''Recursively replace every non-dict value in a (possibly nested) dict with modifier_fn(key, value)

    Wrapped by optional_in_place: by default a modified deep copy of path_dict is returned and
    path_dict is left untouched; pass in_place=True to modify path_dict itself (returns None)

    Parameters
    ----------
    path_dict : dict
        Dict to modify; values which are themselves dicts are recursed into rather than passed to modifier_fn
    modifier_fn : Callable[[key, value], new_value]
        Called on each non-dict (key, value) pair; its return value replaces that value
    '''
    for key, val in path_dict.items():
        if isinstance(val, dict):
            # "modify_dict" here is the decorated wrapper, whose default (in_place=False) would modify and then
            # discard a deep copy of the sub-dict; force in-place so the nested values actually get updated
            modify_dict(val, modifier_fn, in_place=True)
        else:
            path_dict[key] = modifier_fn(key, val) # overwriting an existing key mid-iteration is safe (dict size unchanged)
def iter_len(itera : Iterable):
    '''
    Count the items of an iterable by walking through it, for objects where len() is unavailable (namely generators)
    Beware: a generator (or any other one-shot iterator) is exhausted by this call
    '''
    n_items = 0
    for _ in itera:
        n_items += 1
    return n_items
def sort_dict_by_values(targ_dict : dict, reverse : bool=False) -> dict[Any, Any]:
    '''Return a new dict holding the same key-value pairs as targ_dict, ordered by value

    Order is ascending by default, or descending if reverse=True; keys with equal values
    keep their original relative order (the sort is stable). targ_dict itself is not modified.
    '''
    ordered_items = sorted(targ_dict.items(), key=lambda key_val : key_val[1], reverse=reverse)
    return dict(ordered_items)
# Unit handling
class MissingUnitsError(Exception):
    '''Custom exception, presumably for signalling a value which lacks expected (pint / OpenMM) units; no raise site is visible in this module'''
def hasunits(obj : Any) -> bool:
    '''Duck-typed units check: True if obj has either a "unit" or a "units" attribute

    Naive, but catches both pint and openmm quantities without needing to reference either Quantity type
    '''
    return hasattr(obj, 'unit') or hasattr(obj, 'units')
def strip_units(coords : Union[tuple, PintQuantity, OMMQuantity]) -> Any:
    '''
    Sanitize coordinate tuples (or other values) for cases which require unitless quantities
    Specifically needed since OpenMM and pint each have their own Quantity and Units classes

    A pint Quantity is reduced to its .magnitude, and an OpenMM Quantity to its stored ._value;
    no unit conversion is performed. Any other input (e.g. an already-unitless tuple) is returned unchanged.
    The return type therefore depends on the input, rather than always being a tuple of floats.
    '''
    if isinstance(coords, PintQuantity):
        return coords.magnitude
    if isinstance(coords, OMMQuantity):
        return coords._value # NOTE: private OpenMM attribute, read directly rather than via value_in_unit()
    return coords
# Date / time formatting
@dataclass
class Timestamp:
    '''For storing information on date processing: a datetime format string, plus a regex matching strings it produces'''
    # should be formatted such that the resulting string can be safely used in a filename (i.e. no slashes)
    # NOTE: hours use 24-hour %H, so the trailing %p (AM/PM) is informational only; strptime ignores %p unless %I is used
    fmt_str : str = '%m-%d-%Y_at_%H-%M-%S_%p'
    # must match the strings produced by fmt_str; update both together if either is changed
    # (\w{2} for %p assumes a locale whose AM/PM designators are 2 word characters - TODO confirm for non-English locales)
    regex : Union[str, re.Pattern] = re.compile(r'\d{2}-\d{2}-\d{4}_at_\d{2}-\d{2}-\d{2}_\w{2}')

    def timestamp_now(self) -> str:
        '''Return a string timestamped with the current (naive, local) date and time at the time of calling'''
        return datetime.now().strftime(self.fmt_str)

    def extract_datetime(self, timestr : str) -> datetime:
        '''De-format a string containing a timestamp and extract just the timestamp as a datetime object

        Raises
        ------
        ValueError
            If no part of timestr matches self.regex, or the matched text does not parse with self.fmt_str
        '''
        timestamp_match = re.search(self.regex, timestr) # pull out JUST the datetime formatting component
        if timestamp_match is None: # fail clearly, rather than with an AttributeError on None.group()
            raise ValueError(f'No timestamp matching pattern {self.regex!r} found in string "{timestr}"')
        return datetime.strptime(timestamp_match.group(), self.fmt_str) # convert to datetime object