22
33use std:: collections:: HashMap ;
44use std:: sync:: Arc ;
5+ use std:: time:: { Duration , Instant } ;
56
67use indexmap:: IndexMap ;
78use serde:: Deserialize ;
@@ -10,7 +11,7 @@ use sqlx::{Column, Row};
1011use sqlx_sqlite_conn_mgr:: { AttachedWriteGuard , WriteGuard } ;
1112use tokio:: sync:: { Mutex , RwLock } ;
1213use tokio:: task:: AbortHandle ;
13- use tracing:: debug;
14+ use tracing:: { debug, warn } ;
1415
1516#[ cfg( feature = "observer" ) ]
1617use sqlx_sqlite_observer:: ObservableWriteGuard ;
@@ -97,6 +98,7 @@ pub struct ActiveInterruptibleTransaction {
9798 db_path : String ,
9899 transaction_id : String ,
99100 writer : Option < TransactionWriter > ,
101+ created_at : Instant ,
100102}
101103
102104impl ActiveInterruptibleTransaction {
@@ -105,6 +107,7 @@ impl ActiveInterruptibleTransaction {
105107 db_path,
106108 transaction_id,
107109 writer : Some ( writer) ,
110+ created_at : Instant :: now ( ) ,
108111 }
109112 }
110113
@@ -241,34 +244,71 @@ impl Drop for ActiveInterruptibleTransaction {
241244 }
242245}
243246
247+ /// Default transaction timeout (5 minutes).
248+ const DEFAULT_TRANSACTION_TIMEOUT : Duration = Duration :: from_secs ( 300 ) ;
249+
244250/// Global state tracking all active interruptible transactions.
245251///
246- /// Enforces one interruptible transaction per database path.
252+ /// Enforces one interruptible transaction per database path and applies a configurable
253+ /// timeout. Expired transactions are cleaned up lazily on the next `insert()` or
254+ /// `remove()` call — no background task is needed.
255+ ///
247256/// Uses `Mutex` rather than `RwLock` because all operations require write access,
248257/// and `Mutex<T>` only requires `T: Send` (not `T: Sync`) — avoiding an
249258/// `unsafe impl Sync` that would otherwise be needed due to non-`Sync` inner
250259/// types (`PoolConnection`, raw pointers in observer guards).
251- #[ derive( Clone , Default ) ]
252- pub struct ActiveInterruptibleTransactions (
253- Arc < Mutex < HashMap < String , ActiveInterruptibleTransaction > > > ,
254- ) ;
260+ #[ derive( Clone ) ]
261+ pub struct ActiveInterruptibleTransactions {
262+ inner : Arc < Mutex < HashMap < String , ActiveInterruptibleTransaction > > > ,
263+ timeout : Duration ,
264+ }
265+
266+ impl Default for ActiveInterruptibleTransactions {
267+ fn default ( ) -> Self {
268+ Self :: new ( DEFAULT_TRANSACTION_TIMEOUT )
269+ }
270+ }
255271
256272impl ActiveInterruptibleTransactions {
273+ /// Create a new instance with the given transaction timeout.
274+ pub fn new ( timeout : Duration ) -> Self {
275+ Self {
276+ inner : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
277+ timeout,
278+ }
279+ }
280+
257281 pub async fn insert ( & self , db_path : String , tx : ActiveInterruptibleTransaction ) -> Result < ( ) > {
258282 use std:: collections:: hash_map:: Entry ;
259- let mut txs = self . 0 . lock ( ) . await ;
283+ let mut txs = self . inner . lock ( ) . await ;
260284
261285 match txs. entry ( db_path. clone ( ) ) {
262286 Entry :: Vacant ( e) => {
263287 e. insert ( tx) ;
264288 Ok ( ( ) )
265289 }
266- Entry :: Occupied ( _) => Err ( Error :: TransactionAlreadyActive ( db_path) ) ,
290+ Entry :: Occupied ( mut e) => {
291+ // If the existing transaction has expired, drop it (auto-rollback) and
292+ // replace with the new one.
293+ if e. get ( ) . created_at . elapsed ( ) >= self . timeout {
294+ warn ! (
295+ "Evicting expired transaction for db: {} (age: {:?}, timeout: {:?})" ,
296+ db_path,
297+ e. get( ) . created_at. elapsed( ) ,
298+ self . timeout,
299+ ) ;
300+ // Drop the expired transaction (auto-rollback) before inserting the new one
301+ let _expired = e. insert ( tx) ;
302+ Ok ( ( ) )
303+ } else {
304+ Err ( Error :: TransactionAlreadyActive ( db_path) )
305+ }
306+ }
267307 }
268308 }
269309
270310 pub async fn abort_all ( & self ) {
271- let mut txs = self . 0 . lock ( ) . await ;
311+ let mut txs = self . inner . lock ( ) . await ;
272312 debug ! ( "Aborting {} active interruptible transaction(s)" , txs. len( ) ) ;
273313
274314 for db_path in txs. keys ( ) {
@@ -283,13 +323,17 @@ impl ActiveInterruptibleTransactions {
283323 txs. clear ( ) ;
284324 }
285325
286- /// Remove and return transaction for commit/rollback
326+ /// Remove and return transaction for commit/rollback.
327+ ///
328+ /// Returns `Err(Error::TransactionTimedOut)` if the transaction has exceeded the
329+ /// configured timeout. The expired transaction is dropped (auto-rolled-back) in
330+ /// that case.
287331 pub async fn remove (
288332 & self ,
289333 db_path : & str ,
290334 token_id : & str ,
291335 ) -> Result < ActiveInterruptibleTransaction > {
292- let mut txs = self . 0 . lock ( ) . await ;
336+ let mut txs = self . inner . lock ( ) . await ;
293337
294338 // Validate token before removal
295339 let tx = txs
@@ -300,6 +344,19 @@ impl ActiveInterruptibleTransactions {
300344 return Err ( Error :: InvalidTransactionToken ) ;
301345 }
302346
347+ // Check if the transaction has expired
348+ if tx. created_at . elapsed ( ) >= self . timeout {
349+ warn ! (
350+ "Transaction timed out for db: {} (age: {:?}, timeout: {:?})" ,
351+ db_path,
352+ tx. created_at. elapsed( ) ,
353+ self . timeout,
354+ ) ;
355+ // Drop the expired transaction (auto-rollback via Drop)
356+ txs. remove ( db_path) ;
357+ return Err ( Error :: TransactionTimedOut ( db_path. to_string ( ) ) ) ;
358+ }
359+
303360 // Safe unwrap: we just confirmed the key exists above
304361 Ok ( txs. remove ( db_path) . unwrap ( ) )
305362 }
0 commit comments