@@ -26,7 +26,7 @@ use tokio::{
2626 io:: ReadBuf ,
2727 net:: UdpSocket ,
2828 sync:: mpsc,
29- time:: { sleep, timeout} ,
29+ time:: { interval_at , sleep, timeout, Instant , Interval , MissedTickBehavior } ,
3030 try_join,
3131} ;
3232
@@ -184,6 +184,14 @@ async fn create_tunnel(
184184 . map_err ( UdpTunnelError :: AllocateSocketPool ) ?;
185185 debug ! ( "Allocated tunnel pool" ) ;
186186
187+ let now = Instant :: now ( ) ;
188+
189+ // Create the interval to track keep alive checking
190+ let keep_alive_start = now + KEEP_ALIVE_DELAY ;
191+ let mut keep_alive_interval = interval_at ( keep_alive_start, KEEP_ALIVE_DELAY ) ;
192+
193+ keep_alive_interval. set_missed_tick_behavior ( MissedTickBehavior :: Delay ) ;
194+
187195 // Start the tunnel
188196 Tunnel {
189197 socket,
@@ -192,11 +200,11 @@ async fn create_tunnel(
192200 pool,
193201 write_state : Default :: default ( ) ,
194202 read_buffer : [ 0u8 ; u16:: MAX as usize ] ,
203+ last_keep_alive : now,
204+ keep_alive_interval,
195205 }
196206 . await ;
197207
198- // TODO: Handle connection lost
199-
200208 Ok ( ( ) )
201209}
202210
@@ -274,6 +282,13 @@ async fn handshake_tunnel(socket: &UdpSocket, association: &str) -> Result<u32,
274282 }
275283}
276284
285+ /// Delay between each keep-alive check
286+ const KEEP_ALIVE_DELAY : Duration = Duration :: from_secs ( 10 ) ;
287+
288+ /// When this duration elapses between keep-alive checks for a connection
289+ /// the connection is considered to be dead (4 missed keep-alive check intervals)
290+ const KEEP_ALIVE_TIMEOUT : Duration = Duration :: from_secs ( KEEP_ALIVE_DELAY . as_secs ( ) * 4 ) ;
291+
277292/// Represents a tunnel and its pool of connections that it can
278293/// send data to and receive data from
279294struct Tunnel {
@@ -291,6 +306,10 @@ struct Tunnel {
291306 write_state : TunnelWriteState ,
292307 /// Buffer for reading
293308 read_buffer : [ u8 ; u16:: MAX as usize ] ,
309+ /// Last time a keep-alive message was received through the tunnel
310+ last_keep_alive : Instant ,
311+ /// Interval for polling connection alive checks
312+ keep_alive_interval : Interval ,
294313}
295314
296315/// Holds the state for the current writing progress for a [`Tunnel`]
@@ -365,6 +384,20 @@ impl Tunnel {
365384 ///
366385 /// Should be repeatedly called until it no-longer returns [`Poll::Ready`]
367386 fn poll_read_state ( & mut self , cx : & mut Context < ' _ > ) -> Poll < TunnelReadState > {
387+ // Poll for keep alive
388+ if self . keep_alive_interval . poll_tick ( cx) . is_ready ( ) {
389+ debug ! ( "checking connection alive" ) ;
390+
391+ let now = Instant :: now ( ) ;
392+
393+ let last_alive = self . last_keep_alive . duration_since ( now) ;
394+ if last_alive > KEEP_ALIVE_TIMEOUT {
395+ // Connection to the server has timed out as no keep alive messages were
396+ // given by the server
397+ return Poll :: Ready ( TunnelReadState :: Stop ) ;
398+ }
399+ }
400+
368401 // Try receive a message from the `io`
369402 if ready ! ( Pin :: new( & mut self . socket) . poll_recv_ready( cx) ) . is_err ( ) {
370403 // Cannot read next message stop the tunnel
@@ -403,6 +436,8 @@ impl Tunnel {
403436
404437 // Reply to keep-alive message
405438 TunnelMessage :: KeepAlive => {
439+ self . last_keep_alive = Instant :: now ( ) ;
440+
406441 self . write_state = TunnelWriteState :: Write ( Some ( TunnelMessage :: KeepAlive ) ) ;
407442
408443 // Poll the write state
0 commit comments