Skip to content

Commit 34650a0

Browse files
committed
Clear mps memory during pytest
1 parent e84f6b8 commit 34650a0

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

tests/unit/test_torch.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ def infer_config_from_state_dict(cls, state_dict):
5454
return {"input_size": state_dict["head_a.weight"].shape[1]}
5555

5656

57+
@pytest.fixture(autouse=True)
58+
def mps_memory_cleanup():
59+
"""Fixture to clean up MPS memory after each test."""
60+
yield
61+
if torch.backends.mps.is_available():
62+
torch.mps.empty_cache()
63+
64+
5765
@pytest.fixture
5866
def batch_message():
5967
input_dim = 6

0 commit comments

Comments
 (0)