-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathresult.py
More file actions
53 lines (44 loc) · 1.87 KB
/
result.py
File metadata and controls
53 lines (44 loc) · 1.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from dataclasses import dataclass, field
from typing import Optional
import torch
@dataclass
class Result:
    """Outcome of evaluating one model / dataset / sparsing configuration.

    Equality is based only on ``model_name``, ``dataset`` and ``sparsing_name``.
    The metric fields are declared with ``compare=False``, so two results for the
    same configuration compare equal regardless of their numbers.
    """

    model_name: str
    dataset: str
    sparsing_name: str
    # Accuracy values as fractions (rendered with ``%``, so 0.9 -> "90.00%").
    # Presumably one entry per repeated run -- TODO confirm with callers.
    acc: Optional[list[float]] = field(default=None, compare=False)
    # Presumably the requested fraction to prune; ``as_dict`` reports it
    # under the "Power" key.
    percent2remove: Optional[float] = field(default=None, compare=False)
    # Fraction reported as "Removed %" (rendered with ``%``).
    removed_percentage: Optional[float] = field(default=None, compare=False)

    def _acc_stats(self) -> Optional[tuple[str, str]]:
        """Return the formatted ``(mean, std)`` of ``acc``, or ``None`` if unset/empty.

        The values are converted to float64 so integer inputs (e.g. ``[1, 0]``)
        do not make ``mean()`` fail on an integer tensor. The std is torch's
        default sample (Bessel-corrected) standard deviation. That is undefined
        for a single value, so it is reported as ``'N/A'`` instead of ``'nan%'``.
        """
        if not self.acc:
            return None
        values = torch.tensor(self.acc, dtype=torch.float64)
        mean = f"{values.mean().item():.2%}"
        std = f"{values.std().item():.2%}" if values.numel() > 1 else "N/A"
        return mean, std

    def __str__(self) -> str:
        """Return a multi-line, human-readable summary that ends with a blank line."""
        if self.percent2remove is not None:
            percent_line = f"percent2remove: {self.percent2remove:.2g}"
        else:
            # Same label as the formatted branch, so the line reads consistently.
            percent_line = "percent2remove: None"
        lines = [
            f"Model Name: {self.model_name}",
            f"Dataset: {self.dataset}",
            f"Sparsing Name: {self.sparsing_name}",
            percent_line,
        ]
        stats = self._acc_stats()
        if stats is None:
            lines.append("Accuracy: N/A")
        else:
            mean, std = stats
            lines.append(f"Accuracy Mean: {mean}")
            lines.append(f"Accuracy Std: {std}")
        if self.removed_percentage is not None:
            lines.append(f"Removed %: {self.removed_percentage:.2%}")
        else:
            lines.append("Removed %: N/A")
        # The trailing "\n" entry makes the joined text end with an empty line,
        # visually separating consecutive printed results.
        lines.append("\n")
        return "\n".join(lines)

    def as_dict(self) -> dict[str, str]:
        """Return the result as display-ready strings keyed by column label.

        Missing metrics are rendered as ``'N/A'``, and a missing ``percent2remove``
        as ``'None'``. Note that ``percent2remove`` is stored under the
        ``'Power'`` key.
        """
        stats = self._acc_stats()
        acc_mean, acc_std = stats if stats is not None else ("N/A", "N/A")
        return {
            'Model Name': self.model_name,
            'Dataset': self.dataset,
            'Sparsing Name': self.sparsing_name,
            'Power': f'{self.percent2remove:.2g}' if self.percent2remove is not None else 'None',
            'Accuracy Mean': acc_mean,
            'Accuracy Std': acc_std,
            'Removed %': f'{self.removed_percentage:.2%}' if self.removed_percentage is not None else 'N/A',
        }