Skip to content

Commit fbf36fc

Browse files
committed
Add ALIGNN work to separate branch
1 parent 4171bb2 commit fbf36fc

5 files changed

Lines changed: 789 additions & 3 deletions

File tree

matdeeplearn/common/metrics.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import matplotlib.pyplot as plt
2+
import pandas as pd
3+
import os
4+
from datetime import datetime
5+
6+
root_path = '/global/cfs/projectdirs/m3641/Sidharth/MatDeepLearn_dev/output/'
7+
8+
class MetricMonitor:
9+
'''
10+
Monitor and update various training metrics.
11+
'''
12+
def __init__(self, plotpath=os.path.join(root_path, 'plots'), datapath=os.path.join(root_path, 'data'), epoch_step=5, start_epoch=0) -> None:
13+
self.streams = {}
14+
self.epoch = {}
15+
self.start_epoch = start_epoch
16+
self.epoch_step = epoch_step
17+
self.plotpath = plotpath
18+
self.datapath = datapath
19+
20+
def create_data_stream(self, stream_name):
21+
'''
22+
Create a metric data stream to be updated.
23+
'''
24+
self.streams[stream_name] = []
25+
self.epoch[stream_name] = [self.start_epoch]
26+
27+
def update(self, stream_name, val):
28+
'''
29+
Update a data stream with a specified value.
30+
'''
31+
if stream_name in self.streams:
32+
self.streams[stream_name].append(val)
33+
self.epoch[stream_name].append(self.epoch[stream_name][len(self.epoch[stream_name]) - 1] + self.epoch_step)
34+
35+
def save_outputs(self):
36+
'''
37+
Save post-training metric outputs.
38+
'''
39+
timestamp = datetime.now()
40+
41+
metric_df = pd.DataFrame(self.streams)
42+
print(metric_df)
43+
print(os.path.join(self.datapath, f'train_metric_{timestamp}.csv'))
44+
metric_df.to_csv(os.path.join(self.datapath, f'train_metric_{timestamp}.csv'))
45+
46+
fig, axs = plt.subplots(len(self.streams.keys()))
47+
fig.suptitle(f'Training metrics {timestamp}')
48+
49+
for ax, item in zip(axs, self.streams.items()):
50+
ax.plot(self.epoch[item[0]][:-1], item[1])
51+
ax.set_title(item[0])
52+
53+
print(os.path.join(self.plotpath, f'plot_metrics_{timestamp}.png'))
54+
plt.savefig(os.path.join(self.plotpath, f'plot_metrics_{timestamp}.png'))
55+
56+
class DatasetMetrics:
57+
'''
58+
Analyze a graph dataset for basic properties
59+
and create basic visualization of overall statistics.
60+
'''
61+
def __init__(self) -> None:
62+
pass
63+
64+
class VisualizeGraph:
65+
'''
66+
Visualize input and latent space graphs with heatmap plots.
67+
TODO: Look at old MatDeepLearn to port over latent visualization code.
68+
'''
69+
def __init__(self) -> None:
70+
pass
71+
72+
# Testing code
73+
if __name__ == '__main__':
74+
print('Testing metrics')
75+
76+
m = MetricMonitor()
77+
m.create_data_stream('test1')
78+
m.create_data_stream('test2')
79+
80+
for i in range(100):
81+
m.update('test1', i)
82+
m.update('test2', i)
83+
84+
m.save_outputs()

0 commit comments

Comments
 (0)