-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path4_dash_heatmap.py
More file actions
39 lines (31 loc) · 1.08 KB
/
4_dash_heatmap.py
File metadata and controls
39 lines (31 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import yaml
import numpy as np
from source.dashplot import ParticleTransformerHeatmap, HeatmapDashboard
# Script body: read the run configuration, load the matching inference summary
# from disk, and build the heatmap dashboard application from it.

# YAML documents are normally UTF-8; pass the encoding explicitly so decoding
# does not depend on the platform's locale default.
with open('config.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# Data configuration.
rnd_seed = config['rnd_seed']
dataset = config['dataset']
model = config['model']

# Load the summary file (the intermediate outputs from inference).
summary_path = os.path.join('intermediate_outputs', f"{dataset}_{model}_{rnd_seed}.npy")
# NOTE(security): allow_pickle=True unpickles the file's contents, which can
# execute arbitrary code. Only load summary files produced by a trusted run.
# `.item()` unwraps the loaded array into the single Python object it holds;
# that object is indexed by string keys below.
summary = np.load(summary_path, allow_pickle=True).item()

# Create the particle transformer heatmap object.
particle_features = summary['particle_features']
intermediate_outputs = summary['intermediate_outputs']
heatmap = ParticleTransformerHeatmap(
    particle_features=particle_features,
    intermediate_outputs=intermediate_outputs,
    linear_weights=summary['linear_weights'],
)

# Create the Dash app.
app = HeatmapDashboard(
    channels=summary['channels'],
    num_data=summary['num_data'],
    num_epochs=summary['num_epochs'],
    heatmap=heatmap,
    # A list holding one inner list of eight labels, "Block 1" .. "Block 8".
    # NOTE(review): 8 is presumably the number of model blocks — confirm.
    io_buttons=[[f"Block {i+1}" for i in range(8)]]
)

# Start the app only when executed as a script, not when imported.
if __name__ == '__main__':
    app.run()