@@ -36,7 +36,7 @@ use crate::message::sequence_reset::SequenceReset;
3636use crate :: message:: test_request:: TestRequest ;
3737use crate :: message:: verification:: VerificationFlags ;
3838use crate :: session:: admin_request:: AdminRequest ;
39- use crate :: session:: ctx:: { SessionCtx , TransitionResult , VerificationResult } ;
39+ use crate :: session:: ctx:: { PreProcessDecision , SessionCtx , TransitionResult , VerificationResult } ;
4040use crate :: session:: error:: SessionCreationError ;
4141use crate :: session:: error:: { InternalSendError , InternalSendResultExt , SessionOperationError } ;
4242pub use crate :: session:: error:: { SendError , SendOutcome } ;
@@ -48,7 +48,7 @@ pub(crate) use crate::session::session_ref::InternalSessionRef;
4848pub use crate :: session:: session_ref:: InternalSessionRef ;
4949use crate :: session:: session_ref:: OutboundRequest ;
5050use crate :: session:: state:: SessionState ;
51- use crate :: session:: state:: { AwaitingLogonState , AwaitingLogoutState , TestRequestId } ;
51+ use crate :: session:: state:: { AwaitingLogonState , TestRequestId } ;
5252use crate :: session_schedule:: { SessionPeriodComparison , SessionSchedule } ;
5353use crate :: store:: MessageStore ;
5454use crate :: transport:: writer:: WriterRef ;
@@ -194,29 +194,20 @@ where
194194 }
195195
196196 async fn process_message ( & mut self , message : Message ) -> Result < ( ) , SessionOperationError > {
197+ let message = match self . state . pre_process_inbound ( message) {
198+ PreProcessDecision :: Accept ( msg) => msg,
199+ PreProcessDecision :: Queued => return Ok ( ( ) ) ,
200+ PreProcessDecision :: Disconnect => {
201+ self . state . disconnect_writer ( ) . await ;
202+ return Ok ( ( ) ) ;
203+ }
204+ } ;
205+
197206 let message_type: & str = message
198207 . header ( )
199208 . get ( MSG_TYPE )
200209 . map_err ( |_| SessionOperationError :: MissingField ( "MSG_TYPE" ) ) ?;
201210
202- if let SessionState :: AwaitingResend ( state) = & mut self . state {
203- let seq_number = get_msg_seq_num ( & message) ;
204- if seq_number > state. end_seq_number && message_type != ResendRequest :: MSG_TYPE {
205- state. inbound_queue . push_back ( message) ;
206- return Ok ( ( ) ) ;
207- }
208- }
209-
210- // TODO: add state-level pre-process check that validates whether the message type
211- // is acceptable in the current state (e.g. AwaitingLogon rejects non-Logon,
212- // unexpected Logon in Active should be rejected per FIX spec).
213- if let SessionState :: AwaitingLogon ( _) = & mut self . state
214- && message_type != Logon :: MSG_TYPE
215- {
216- self . state . disconnect_writer ( ) . await ;
217- return Ok ( ( ) ) ;
218- }
219-
220211 let flags = VerificationFlags :: for_message ( & message) ?;
221212 if let VerificationResult :: Issue ( result) = self
222213 . state
@@ -286,46 +277,42 @@ where
286277 }
287278
288279 async fn check_end_of_resend ( & mut self ) -> Result < ( ) , SessionOperationError > {
289- let backlog = if let SessionState :: AwaitingResend ( state) = & mut self . state {
290- if self . ctx . store . next_target_seq_number ( ) > state. end_seq_number {
291- let inbound_queue = std:: mem:: take ( & mut state. inbound_queue ) ;
292- let new_state = SessionState :: new_active (
293- state. writer . clone ( ) ,
294- self . ctx . config . heartbeat_interval ,
295- ) ;
296- self . apply_transition ( TransitionResult :: TransitionTo ( new_state) )
297- . await ;
298- Some ( inbound_queue)
299- } else {
300- None
301- }
280+ let completed = if let SessionState :: AwaitingResend ( state) = & mut self . state {
281+ state. try_complete (
282+ self . ctx . store . next_target_seq_number ( ) ,
283+ self . ctx . config . heartbeat_interval ,
284+ )
302285 } else {
303286 None
304287 } ;
305288
306- if let Some ( mut inbound_queue) = backlog {
307- // we have reached the end of the resend,
308- // process queued messages and resume normal operation
309- debug ! ( "resend is done, processing backlog" ) ;
310- while let Some ( msg) = inbound_queue. pop_front ( ) {
311- let seq_number: u64 = msg. get ( MSG_SEQ_NUM ) . unwrap_or_else ( |e| {
312- error ! ( "failed to get seq number: {:?}" , e) ;
313- 0
314- } ) ;
315- let msg_type: & str = msg. header ( ) . get ( MSG_TYPE ) . unwrap_or ( "" ) ;
316- debug ! ( seq_number, msg_type, "processing queued message" ) ;
317-
318- if msg_type == ResendRequest :: MSG_TYPE {
319- // ResendRequest was already processed when it arrived (it bypasses
320- // the queue in process_message). Just increment the target seq number
321- // for sequence accounting purposes.
322- self . ctx . store . increment_target_seq_number ( ) . await ?;
323- } else {
324- self . process_message ( msg) . await ?;
325- }
289+ let Some ( ( new_state, mut backlog) ) = completed else {
290+ return Ok ( ( ) ) ;
291+ } ;
292+
293+ self . apply_transition ( TransitionResult :: TransitionTo ( new_state) )
294+ . await ;
295+
296+ // Process queued messages and resume normal operation
297+ debug ! ( "resend is done, processing backlog" ) ;
298+ while let Some ( msg) = backlog. pop_front ( ) {
299+ let seq_number: u64 = msg. get ( MSG_SEQ_NUM ) . unwrap_or_else ( |e| {
300+ error ! ( "failed to get seq number: {:?}" , e) ;
301+ 0
302+ } ) ;
303+ let msg_type: & str = msg. header ( ) . get ( MSG_TYPE ) . unwrap_or ( "" ) ;
304+ debug ! ( seq_number, msg_type, "processing queued message" ) ;
305+
306+ if msg_type == ResendRequest :: MSG_TYPE {
307+ // ResendRequest was already processed when it arrived (it bypasses
308+ // the queue in process_message). Just increment the target seq number
309+ // for sequence accounting purposes.
310+ self . ctx . store . increment_target_seq_number ( ) . await ?;
311+ } else {
312+ self . process_message ( msg) . await ?;
326313 }
327- debug ! ( "resend backlog is cleared, resuming normal operation" ) ;
328314 }
315+ debug ! ( "resend backlog is cleared, resuming normal operation" ) ;
329316
330317 Ok ( ( ) )
331318 }
@@ -346,39 +333,14 @@ where
346333 }
347334
348335 async fn on_disconnect ( & mut self , reason : String ) {
349- let transition = match self . state {
350- SessionState :: Active ( _)
351- | SessionState :: AwaitingLogon ( _)
352- | SessionState :: AwaitingResend ( _) => {
353- self . state . disconnect_writer ( ) . await ;
354- TransitionResult :: TransitionTo ( SessionState :: new_disconnected ( true , & reason) )
355- }
356- SessionState :: Disconnected ( _) => {
357- warn ! ( "disconnect message was received, but the session is already disconnected" ) ;
358- TransitionResult :: Stay
359- }
360- SessionState :: AwaitingLogout ( AwaitingLogoutState { reconnect, .. } ) => {
361- TransitionResult :: TransitionTo ( SessionState :: new_disconnected ( reconnect, & reason) )
362- }
363- } ;
336+ self . state . disconnect_writer ( ) . await ;
337+ let transition = self . state . on_disconnect ( & reason) ;
364338 self . apply_transition ( transition) . await ;
365339 }
366340
367341 async fn on_logon ( & mut self ) -> Result < ( ) , SessionOperationError > {
368- if let SessionState :: AwaitingLogon ( AwaitingLogonState { writer, .. } ) = & self . state {
369- let writer = writer. clone ( ) ;
370- // happy logon flow, the session is now active
371- self . apply_transition ( TransitionResult :: TransitionTo ( SessionState :: new_active (
372- writer,
373- self . ctx . config . heartbeat_interval ,
374- ) ) )
375- . await ;
376- self . ctx . application . on_logon ( ) . await ;
377- self . ctx . store . increment_target_seq_number ( ) . await ?;
378- } else {
379- error ! ( "received unexpected logon message" ) ;
380- }
381-
342+ let transition = self . state . on_peer_logon ( & mut self . ctx ) . await ?;
343+ self . apply_transition ( transition) . await ;
382344 Ok ( ( ) )
383345 }
384346
@@ -394,26 +356,9 @@ where
394356 . on_logout ( "peer has logged us out" )
395357 . await ;
396358
397- match self . state {
398- // if the session is already disconnected, we have nothing else to do
399- SessionState :: Disconnected ( ..) => { }
400- // if we initiated the logout, preserve the reconnect flag
401- SessionState :: AwaitingLogout ( AwaitingLogoutState { reconnect, .. } ) => {
402- self . state . disconnect_writer ( ) . await ;
403- self . apply_transition ( TransitionResult :: TransitionTo (
404- SessionState :: new_disconnected ( reconnect, "logout completed" ) ,
405- ) )
406- . await ;
407- }
408- // otherwise assume it makes sense to try to reconnect
409- _ => {
410- self . state . disconnect_writer ( ) . await ;
411- self . apply_transition ( TransitionResult :: TransitionTo (
412- SessionState :: new_disconnected ( true , "peer has logged us out" ) ,
413- ) )
414- . await ;
415- }
416- }
359+ self . state . disconnect_writer ( ) . await ;
360+ let transition = self . state . on_peer_logout ( ) ;
361+ self . apply_transition ( transition) . await ;
417362
418363 self . ctx . store . increment_target_seq_number ( ) . await ?;
419364 Ok ( ( ) )
0 commit comments