-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Expand file tree
/
Copy pathtest_aggregation.py
More file actions
42 lines (32 loc) · 1.42 KB
/
test_aggregation.py
File metadata and controls
42 lines (32 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import unittest
from aggregation import StableMean
class TestStableMean(unittest.TestCase):
    """Unit tests for the ``StableMean`` metric imported from ``aggregation``.

    Every test gets a fresh metric from ``setUp``, feeds it values via
    ``update`` (optionally with a ``weight``) and checks the value returned
    by ``compute`` against the expected (weighted) arithmetic mean.
    """

    def setUp(self):
        """Create a fresh, empty ``StableMean`` for each test."""
        self.metric = StableMean()

    def _assert_mean(self, result, expected, places=5):
        """Assert that the single-element ``result`` is close to ``expected``.

        ``float()`` accepts both one-element tensors and plain Python numbers.
        Comparing with a tolerance avoids depending on the exact
        floating-point rounding of the mean computation.
        """
        self.assertAlmostEqual(float(result), expected, places=places)

    def test_compute_empty(self):
        """With no updates, ``compute`` is expected to return 0."""
        self._assert_mean(self.metric.compute(), 0.0)

    def test_compute_single_value(self):
        """The mean of a single value is that value."""
        self.metric.update(torch.tensor(1.0))
        self._assert_mean(self.metric.compute(), 1.0)

    def test_compute_weighted_single_value(self):
        """A weight on a lone value does not change the mean."""
        self.metric.update(torch.tensor(1.0), weight=torch.tensor(2.0))
        self._assert_mean(self.metric.compute(), 1.0)

    def test_compute_multiple_values(self):
        """Unweighted values average to their arithmetic mean, (1+2+3)/3."""
        for value in (1.0, 2.0, 3.0):
            self.metric.update(torch.tensor(value))
        self._assert_mean(self.metric.compute(), 2.0)

    def test_compute_weighted_multiple_values(self):
        """Weighted values average to sum(w * x) / sum(w)."""
        values = (1.0, 2.0, 3.0)
        weights = (1.0, 2.0, 3.0)
        for value, weight in zip(values, weights):
            self.metric.update(torch.tensor(value), weight=torch.tensor(weight))
        # (1*1 + 2*2 + 3*3) / (1 + 2 + 3) = 14 / 6, about 2.3333.
        # Keep the tolerance tight: with places=0 (i.e. +/-0.5) a wrong mean
        # such as 13/6 would also pass.
        expected = sum(w * v for v, w in zip(values, weights)) / sum(weights)
        self._assert_mean(self.metric.compute(), expected)
if __name__ == "__main__":
    # Compare the module's real ``__name__`` (not the string literal
    # '__name__') so the suite runs when this file is executed directly.
    unittest.main()