|
| 1 | +use crate::message::heartbeat::Heartbeat; |
1 | 2 | use crate::message::logout::Logout; |
2 | 3 | use crate::message::reject::Reject; |
3 | 4 | use crate::message::verification::verify_message; |
4 | 5 | use crate::message::verification_issue::{CompIdType, MessageError, VerificationIssue}; |
5 | 6 | use crate::session::ctx::{SessionCtx, TransitionResult}; |
| 7 | +use crate::session::error::{InternalSendResultExt, SessionOperationError}; |
| 8 | +use crate::session::get_msg_seq_num; |
6 | 9 | use crate::session::outbound; |
7 | 10 | use crate::session::state::SessionState; |
8 | 11 | use crate::transport::writer::WriterRef; |
9 | 12 | use hotfix_message::Part; |
10 | 13 | use hotfix_message::message::Message; |
11 | | -use hotfix_message::session_fields::{MSG_SEQ_NUM, SessionRejectReason}; |
| 14 | +use hotfix_message::session_fields::{ |
| 15 | + BEGIN_SEQ_NO, END_SEQ_NO, MSG_SEQ_NUM, NEW_SEQ_NO, SessionRejectReason, TEST_REQ_ID, |
| 16 | +}; |
12 | 17 | use hotfix_store::MessageStore; |
13 | 18 | use tracing::error; |
14 | 19 | use tracing::warn; |
@@ -256,6 +261,124 @@ async fn handle_verification_error<A, S: MessageStore>( |
256 | 261 | } |
257 | 262 | } |
258 | 263 |
|
| 264 | +pub(crate) async fn on_test_request<A, S: MessageStore>( |
| 265 | + ctx: &mut SessionCtx<A, S>, |
| 266 | + writer: &WriterRef, |
| 267 | + message: &Message, |
| 268 | +) -> Result<(), SessionOperationError> { |
| 269 | + let req_id: &str = message.get(TEST_REQ_ID).unwrap_or_else(|_| { |
| 270 | + // TODO: send reject? |
| 271 | + todo!() |
| 272 | + }); |
| 273 | + |
| 274 | + ctx.store.increment_target_seq_number().await?; |
| 275 | + |
| 276 | + outbound::send_message(ctx, writer, Heartbeat::for_request(req_id.to_string())) |
| 277 | + .await |
| 278 | + .with_send_context("heartbeat response")?; |
| 279 | + |
| 280 | + Ok(()) |
| 281 | +} |
| 282 | + |
| 283 | +pub(crate) async fn on_sequence_reset<A, S: MessageStore>( |
| 284 | + ctx: &mut SessionCtx<A, S>, |
| 285 | + writer: &WriterRef, |
| 286 | + message: &Message, |
| 287 | +) -> Result<(), SessionOperationError> { |
| 288 | + let msg_seq_num = get_msg_seq_num(message); |
| 289 | + |
| 290 | + let end: u64 = match message.get(NEW_SEQ_NO) { |
| 291 | + Ok(new_seq_no) => new_seq_no, |
| 292 | + Err(err) => { |
| 293 | + error!( |
| 294 | + "received sequence reset message without new sequence number: {:?}", |
| 295 | + err |
| 296 | + ); |
| 297 | + let reject = Reject::new(msg_seq_num) |
| 298 | + .session_reject_reason(SessionRejectReason::RequiredTagMissing) |
| 299 | + .text("missing NewSeqNo tag in sequence reset message"); |
| 300 | + outbound::send_message(ctx, writer, reject) |
| 301 | + .await |
| 302 | + .with_send_context("reject for missing NEW_SEQ_NO")?; |
| 303 | + |
| 304 | + // note: we don't increment the target seq number here |
| 305 | + // this is an ambiguous case in the specification, but leaving the |
| 306 | + // sequence number as is feels the safest |
| 307 | + return Ok(()); |
| 308 | + } |
| 309 | + }; |
| 310 | + |
| 311 | + // sequence resets cannot move the target seq number backwards |
| 312 | + // regardless of whether the message is a gap fill or not |
| 313 | + if end <= ctx.store.next_target_seq_number() { |
| 314 | + error!( |
| 315 | + "received sequence reset message which would move target seq number backwards: {end}", |
| 316 | + ); |
| 317 | + let text = format!("attempt to lower sequence number, invalid value NewSeqNo(36)={end}"); |
| 318 | + let reject = Reject::new(msg_seq_num) |
| 319 | + .session_reject_reason(SessionRejectReason::ValueIsIncorrect) |
| 320 | + .text(&text); |
| 321 | + outbound::send_message(ctx, writer, reject) |
| 322 | + .await |
| 323 | + .with_send_context("reject for invalid sequence reset")?; |
| 324 | + return Ok(()); |
| 325 | + } |
| 326 | + |
| 327 | + ctx.store.set_target_seq_number(end - 1).await?; |
| 328 | + Ok(()) |
| 329 | +} |
| 330 | + |
| 331 | +pub(crate) async fn on_resend_request<A, S: MessageStore>( |
| 332 | + ctx: &mut SessionCtx<A, S>, |
| 333 | + writer: &WriterRef, |
| 334 | + message: &Message, |
| 335 | +) -> Result<(), SessionOperationError> { |
| 336 | + let msg_seq_num = get_msg_seq_num(message); |
| 337 | + let expected = ctx.store.next_target_seq_number(); |
| 338 | + |
| 339 | + let begin_seq_number: u64 = match message.get(BEGIN_SEQ_NO) { |
| 340 | + Ok(seq_number) => seq_number, |
| 341 | + Err(_) => { |
| 342 | + let reject = Reject::new(msg_seq_num) |
| 343 | + .session_reject_reason(SessionRejectReason::RequiredTagMissing) |
| 344 | + .text("missing begin sequence number for resend request"); |
| 345 | + outbound::send_message(ctx, writer, reject) |
| 346 | + .await |
| 347 | + .with_send_context("reject for missing BEGIN_SEQ_NO")?; |
| 348 | + return Ok(()); |
| 349 | + } |
| 350 | + }; |
| 351 | + |
| 352 | + let end_seq_number: u64 = match message.get(END_SEQ_NO) { |
| 353 | + Ok(seq_number) => { |
| 354 | + let last_seq_number = ctx.store.next_sender_seq_number() - 1; |
| 355 | + if seq_number == 0 { |
| 356 | + last_seq_number |
| 357 | + } else { |
| 358 | + std::cmp::min(seq_number, last_seq_number) |
| 359 | + } |
| 360 | + } |
| 361 | + Err(_) => { |
| 362 | + let reject = Reject::new(msg_seq_num) |
| 363 | + .session_reject_reason(SessionRejectReason::RequiredTagMissing) |
| 364 | + .text("missing end sequence number for resend request"); |
| 365 | + outbound::send_message(ctx, writer, reject) |
| 366 | + .await |
| 367 | + .with_send_context("reject for missing END_SEQ_NO")?; |
| 368 | + return Ok(()); |
| 369 | + } |
| 370 | + }; |
| 371 | + |
| 372 | + // Only increment target seq if seq matches expected |
| 373 | + if msg_seq_num == expected { |
| 374 | + ctx.store.increment_target_seq_number().await?; |
| 375 | + } |
| 376 | + |
| 377 | + outbound::resend_messages(ctx, writer, begin_seq_number, end_seq_number).await?; |
| 378 | + |
| 379 | + Ok(()) |
| 380 | +} |
| 381 | + |
259 | 382 | #[cfg(test)] |
260 | 383 | mod tests { |
261 | 384 | use super::*; |
|
0 commit comments