11import tempfile
2+ from pathlib import Path
23
34import numpy as np
45import pytest
@@ -172,13 +173,17 @@ def test_rnn_checkpoint_save_load(simple_message):
172173 # First pass to initialize model
173174 proc (simple_message )
174175
175- # Save full checkpoint (state_dict + config)
176- with tempfile .NamedTemporaryFile (suffix = ".pt" ) as tmp :
177- proc .save_checkpoint (tmp .name )
176+ # Create a temporary file that is closed immediately
177+ with tempfile .NamedTemporaryFile (suffix = ".pt" , delete = False ) as tmp :
178+ checkpoint_path = Path (tmp .name )
179+
180+ try :
181+ # Save full checkpoint (state_dict + config)
182+ proc .save_checkpoint (str (checkpoint_path ))
178183
179184 # Load from checkpoint
180185 proc2 = RNNProcessor (
181- checkpoint_path = tmp . name ,
186+ checkpoint_path = str ( checkpoint_path ) ,
182187 single_precision = single_precision ,
183188 device = "cpu" ,
184189 model_kwargs = {
@@ -200,6 +205,10 @@ def test_rnn_checkpoint_save_load(simple_message):
200205 f"Mismatch in parameter { key } "
201206 )
202207
208+ finally :
209+ # Ensure the temporary file is deleted
210+ checkpoint_path .unlink (missing_ok = True )
211+
203212
204213def test_rnn_partial_fit_multiloss (simple_message ):
205214 hidden_size = 16
@@ -322,7 +331,7 @@ def test_rnn_preserve_state_batch_size_change():
322331 single_precision = True ,
323332 device = "cpu" ,
324333 preserve_state_across_windows = True ,
325- model_kwargs = {"hidden_size" : hidden_size , "output_heads " : output_size },
334+ model_kwargs = {"hidden_size" : hidden_size , "output_size " : output_size },
326335 )
327336
328337 # First message: 1 window
0 commit comments