Skip to content

Commit cc2a023

Browse files
committed
Fix GHA Windows runner error about tempfile lock.
1 parent ce8e69f commit cc2a023

2 files changed

Lines changed: 27 additions & 9 deletions

File tree

tests/unit/test_rnn.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import tempfile
2+
from pathlib import Path
23

34
import numpy as np
45
import pytest
@@ -172,13 +173,17 @@ def test_rnn_checkpoint_save_load(simple_message):
172173
# First pass to initialize model
173174
proc(simple_message)
174175

175-
# Save full checkpoint (state_dict + config)
176-
with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
177-
proc.save_checkpoint(tmp.name)
176+
# Create a temporary file that is closed immediately
177+
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp:
178+
checkpoint_path = Path(tmp.name)
179+
180+
try:
181+
# Save full checkpoint (state_dict + config)
182+
proc.save_checkpoint(str(checkpoint_path))
178183

179184
# Load from checkpoint
180185
proc2 = RNNProcessor(
181-
checkpoint_path=tmp.name,
186+
checkpoint_path=str(checkpoint_path),
182187
single_precision=single_precision,
183188
device="cpu",
184189
model_kwargs={
@@ -200,6 +205,10 @@ def test_rnn_checkpoint_save_load(simple_message):
200205
f"Mismatch in parameter {key}"
201206
)
202207

208+
finally:
209+
# Ensure the temporary file is deleted
210+
checkpoint_path.unlink(missing_ok=True)
211+
203212

204213
def test_rnn_partial_fit_multiloss(simple_message):
205214
hidden_size = 16
@@ -322,7 +331,7 @@ def test_rnn_preserve_state_batch_size_change():
322331
single_precision=True,
323332
device="cpu",
324333
preserve_state_across_windows=True,
325-
model_kwargs={"hidden_size": hidden_size, "output_heads": output_size},
334+
model_kwargs={"hidden_size": hidden_size, "output_size": output_size},
326335
)
327336

328337
# First message: 1 window

tests/unit/test_transformer.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import tempfile
2+
from pathlib import Path
23

34
import numpy as np
45
import pytest
@@ -172,13 +173,17 @@ def test_transformer_checkpoint_save_load(simple_message):
172173
# First pass to initialize model
173174
proc(simple_message)
174175

175-
# Save full checkpoint (state_dict + config)
176-
with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
177-
proc.save_checkpoint(tmp.name)
176+
# Create a temporary file that is closed immediately
177+
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp:
178+
checkpoint_path = Path(tmp.name)
179+
180+
try:
181+
# Save full checkpoint (state_dict + config)
182+
proc.save_checkpoint(str(checkpoint_path))
178183

179184
# Load from checkpoint
180185
proc2 = TransformerProcessor(
181-
checkpoint_path=tmp.name,
186+
checkpoint_path=str(checkpoint_path),
182187
single_precision=single_precision,
183188
device="cpu",
184189
model_kwargs={
@@ -200,6 +205,10 @@ def test_transformer_checkpoint_save_load(simple_message):
200205
f"Mismatch in parameter {key}"
201206
)
202207

208+
finally:
209+
# Ensure the temporary file is deleted
210+
checkpoint_path.unlink(missing_ok=True)
211+
203212

204213
def test_transformer_partial_fit_multiloss(simple_message):
205214
hidden_size = 16

0 commit comments

Comments
 (0)