-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathbasic_stats.py
More file actions
259 lines (230 loc) · 8.8 KB
/
basic_stats.py
File metadata and controls
259 lines (230 loc) · 8.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
"""
Basic Stats module.
"""
from collections.abc import Iterable
import dataclasses
import os
import re
import torch
import numpy as np
from pymarlin.utils.logger.logging_utils import getlogger
try:
import psutil
except ImportError:
pass
@dataclasses.dataclass
class StatInitArguments:
    # Raw docstring: the exclude_list example contains backslashes (`\.`), which in a
    # non-raw string is an invalid escape sequence (DeprecationWarning, future SyntaxError).
    r"""
    Stats Arguments.

    Args:
        log_steps (int): Interval of logging. If log_steps is 50 and metric X is
            updated every step, then only every 50th step value will be logged.
            Defaults to 1.
        update_system_stats (bool): Logs system stats like CPU, RAM, GPU usage
            when enabled. Defaults to False.
        log_model_steps (int): Interval to log model weight norm and grad norm.
            Defaults to 1000.
        exclude_list (str): Regular expression which, when matched against a
            parameter name, suppresses weight norm and grad norm logging for
            that parameter.
            Defaults to r"bias|LayerNorm|layer\.[3-9]|layer\.1(?!1)|layer\.2(?!3)"
    """
    log_steps: int = 1
    update_system_stats: bool = False
    log_model_steps: int = 1000
    exclude_list: str = r"bias|LayerNorm|layer\.[3-9]|layer\.1(?!1)|layer\.2(?!3)"
class BasicStats:
    """
    Basic Stats class provides a common place for collecting long interval stats and
    step interval stats that can be recorded in the various writers provided at the
    time of calling rebuild() in trainer. This class is used as a Singleton pattern
    via global_stats provided in the __init__.py file.
    """

    def __init__(self, args: 'StatInitArguments', writers=None):
        """
        Args:
            args (StatInitArguments): stat collection and logging options.
            writers (Iterable, optional): writer objects (log_scalar, log_multi,
                log_image, ... interface). Defaults to None; expected to be set
                via rebuild() before any log_* call iterates over them.
        """
        self.args = args
        self.logger = getlogger(__name__)
        self.reset()
        self.writers = writers

    def rebuild(self, args: 'StatInitArguments', writers: Iterable):
        """
        Rebuild Stat Args and Writers.
        """
        self.args = args
        self.writers = writers

    def reset(self):
        """
        Reset all stats (both step interval and long interval).
        """
        self.reset_short()
        self.reset_long()

    def reset_short(self):
        """
        Reset step interval stats.
        """
        self.scalars_short = {}
        self.multi_short = {}

    def reset_long(self):
        """
        Reset long interval stats.
        """
        self.scalars_long = {}
        self.multi_long = {}
        self.images = {}
        self.pr = {}  # key -> (preds, labels)
        self.histogram = {}  # key -> vals
        self.embedding = {}  # key -> (embeddings, labels)

    def update(self, k, v, frequent=False):
        """
        Update step interval (frequent=True) or long interval scalar stats.
        """
        if frequent:
            self.scalars_short[k] = v
        else:
            self.scalars_long[k] = v

    def update_multi(self, k, v: dict, frequent=False):
        """
        Update step interval (frequent=True) or long interval multiple scalar stats.
        """
        if frequent:
            self.multi_short[k] = v
        else:
            self.multi_long[k] = v

    def update_matplotlib_figure(self, fig, tag):
        """
        Render a matplotlib figure to a PNG image array and store it under `tag`.

        No-op (with an info log) when PIL is not installed.
        """
        try:
            from PIL import Image
        except ImportError:
            self.logger.info("Can't import PIL, can't update matplotlib figure")
            return
        import io
        with io.BytesIO() as output:
            fig.savefig(output, format="PNG")
            output.seek(0)
            contents = output.getvalue()
            img = Image.open(io.BytesIO(contents))
            image_arr = np.array(img)
        self.update_image(tag, image_arr, 'HWC')

    def update_image(self, k, v, dataformats='HW'):
        """
        Update image.
        Will be logged with infrequent metric.
        """
        self.images[k] = (v, dataformats)

    def update_pr(self, k, preds, labels):
        """
        Update pr curve stats.
        Only binary classification
        preds = probabilities
        """
        self.pr[k] = (preds, labels)

    def update_histogram(self, k, vals, extend=False):
        """
        Update histogram stats.

        With extend=True, concatenates `vals` onto any existing tensor for key `k`;
        otherwise replaces it.
        """
        if not extend or k not in self.histogram:
            self.histogram[k] = vals
        else:
            self.histogram[k] = torch.cat((self.histogram[k], vals))

    def update_embedding(self, k, embs, labels):
        """
        Update embeddings.
        Used to project embeddings with corresponding labels (numerical).
        """
        self.embedding[k] = (embs, labels)

    def log_stats(self, step, force=False):
        """
        Log step interval stats to corresponding writers every `log_steps` steps
        (or unconditionally when force=True), then reset the step interval stats.
        """
        if (step % self.args.log_steps == 0) or force:
            self.logger.debug(f'logging short stats for step {step}')
            if self.args.update_system_stats:
                self.update_system_stats()
            for writer in self.writers:
                for k, v in self.scalars_short.items():
                    writer.log_scalar(k, v, step)
                for k, v in self.multi_short.items():
                    writer.log_multi(k, v, step)
            self.reset_short()

    def update_system_stats(self):
        """
        Update system stats related to Memory and Compute (CPU and GPUs) usage.

        Best-effort: any failure (e.g. psutil unavailable, no GPU) is logged as a
        warning instead of raised.
        """
        try:
            process = psutil.Process(os.getpid())
            # RAM
            self.update('system/RAM/memory_used_pct', psutil.virtual_memory().percent, frequent=True)
            self.update('system/RAM/memory_elr_pct',
                        process.memory_info().rss / psutil.virtual_memory().total * 100,
                        frequent=True)
            # CPU (interval=1 blocks for one second to sample usage)
            self.update('system/CPU/pct', psutil.cpu_percent(interval=1), frequent=True)
            # GPU
            # NOTE(review): StatInitArguments does not declare a `device` field; this
            # assumes rebuild() supplies args with a `.device` attribute — confirm.
            if self.args.device.type != 'cpu':
                self.update('system/GPU0/memory_used_pct',
                            torch.cuda.memory_allocated(device=self.args.device) /
                            torch.cuda.get_device_properties(self.args.device).total_memory * 100,
                            frequent=True)
        except Exception as e:  # pylint: disable=broad-except
            self.logger.warning(f'error in update_system_stats : {e}')

    def log_long_stats(self, step):
        """
        Log long interval stats to corresponding writers, then reset them.
        """
        self.logger.debug(f'logging long stats for step {step}')
        for writer in self.writers:
            for k, v in self.scalars_long.items():
                writer.log_scalar(k, v, step)
            for k, v in self.multi_long.items():
                writer.log_multi(k, v, step)
            for k, v in self.images.items():
                writer.log_image(k, v[0], step, dataformats=v[1])
            for k, v in self.pr.items():
                writer.log_pr_curve(k, preds=v[0], labels=v[1], step=step)
            for k, v in self.histogram.items():
                writer.log_histogram(k, v, step)
            for k, v in self.embedding.items():
                writer.log_embedding(tag=k, mat=v[0], labels=v[1], step=step)
        self.reset_long()

    def log_args(self, args):
        """
        Log Arguments to corresponding writers.
        """
        self.logger.debug('Logging args to file.')
        for writer in self.writers:
            writer.log_args(args)

    def log_model(self, step, model, force=False, grad_scale=1):
        """
        Log model weight and gradient norms to corresponding writers every
        `log_model_steps` steps (or unconditionally when force=True).

        Args:
            step (int): current training step.
            model: a torch module exposing named_parameters().
            force (bool): log regardless of the step interval.
            grad_scale: divisor applied to gradients (e.g. loss scaling factor).
        """
        self.logger.debug('basic - beginning log_model')
        if (step % self.args.log_model_steps == 0) or force:
            self.logger.info(f'force {force}, logging model stats for step {step}')
            flat_weights, flat_grads = self._get_flat_param_vals(model, grad_scale)
            for writer in self.writers:
                self.logger.debug(f'basic - log_model - beginning writer.log_model {type(writer)}')
                writer.log_model(flat_weights, flat_grads, step)
                self.logger.debug(f'basic - log_model - finishing writer.log_model {type(writer)}')
        self.logger.debug('basic - finishing log_model')

    def log_graph(self, model, device):
        """
        Log graph to corresponding writers.
        """
        self.logger.debug('logging graph')
        for writer in self.writers:
            writer.log_graph(model, device=device)

    def finish(self):
        """
        Call finish() on all writers.
        """
        for writer in self.writers:
            writer.finish()

    def _get_flat_param_vals(self, model, grad_scale):
        """
        Collect flattened weights and (scaled) gradients for all non-excluded
        parameters. Returns (flat_weights, flat_grads) dicts keyed by param name;
        params without gradients are omitted from flat_grads.
        """
        self.logger.debug('basic - beginning _get_flat_param_vals')
        flat_weights = {}
        flat_grads = {}
        for name, param in model.named_parameters():
            if self._exclude_param(name):
                continue
            flat_weights[name] = param.data.view(-1)
            if param.grad is not None:
                flat_grads[name] = param.grad.view(-1) / grad_scale
        self.logger.debug('basic - finishing _get_flat_param_vals')
        return flat_weights, flat_grads

    def _exclude_param(self, param_name):
        """Return True when param_name matches the configured exclude_list regex."""
        return re.search(self.args.exclude_list, param_name) is not None