Skip to content

Commit 6363930

Browse files
committed
Add light unit test suite for inbound message processing module
1 parent 99e486d commit 6363930

3 files changed

Lines changed: 284 additions & 10 deletions

File tree

crates/hotfix/src/session/inbound.rs

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,229 @@ pub(crate) async fn handle_original_sending_time_missing<A, S: MessageStore>(
163163
error!("failed to increment target seq number: {:?}", err);
164164
}
165165
}
166+
167+
#[cfg(test)]
168+
mod tests {
169+
use super::*;
170+
use crate::session::test_utils::{
171+
FakeMessageStore, create_test_ctx, create_writer, extract_field, extract_msg_type,
172+
};
173+
use crate::transport::writer::WriterMessage;
174+
175+
#[tokio::test]
176+
async fn handle_incorrect_begin_string_returns_transition_to_disconnected() {
177+
let mut ctx = create_test_ctx(FakeMessageStore::new());
178+
let (writer, mut rx) = create_writer();
179+
180+
let result = handle_incorrect_begin_string(&mut ctx, &writer, "FIX.4.0".to_string()).await;
181+
182+
assert!(matches!(
183+
result,
184+
TransitionResult::TransitionTo(SessionState::Disconnected(_))
185+
));
186+
187+
// Should send a Logout containing the bad begin string, then disconnect
188+
let msg = rx.recv().await.unwrap();
189+
match &msg {
190+
WriterMessage::SendMessage(raw) => {
191+
assert_eq!(extract_msg_type(raw.as_bytes()).as_deref(), Some("5"));
192+
let text = extract_field(raw.as_bytes(), 58).expect("expected Text(58) field");
193+
assert!(
194+
text.contains("FIX.4.0"),
195+
"logout text should mention the bad begin string, got: {text}"
196+
);
197+
}
198+
_ => panic!("expected SendMessage, got {msg:?}"),
199+
}
200+
assert!(matches!(
201+
rx.recv().await.unwrap(),
202+
WriterMessage::Disconnect
203+
));
204+
205+
// Sender seq number should have been incremented for the logout
206+
assert_eq!(ctx.store.next_sender_seq, 2);
207+
}
208+
209+
#[tokio::test]
210+
async fn handle_incorrect_comp_id_returns_transition_to_disconnected() {
211+
let mut ctx = create_test_ctx(FakeMessageStore::new());
212+
let (writer, mut rx) = create_writer();
213+
214+
let result = handle_incorrect_comp_id(
215+
&mut ctx,
216+
&writer,
217+
"BAD_COMP".to_string(),
218+
CompIdType::Sender,
219+
1,
220+
)
221+
.await;
222+
223+
assert!(matches!(
224+
result,
225+
TransitionResult::TransitionTo(SessionState::Disconnected(_))
226+
));
227+
228+
// First message: Reject (35=3) mentioning the bad comp ID
229+
let msg = rx.recv().await.unwrap();
230+
match &msg {
231+
WriterMessage::SendMessage(raw) => {
232+
assert_eq!(extract_msg_type(raw.as_bytes()).as_deref(), Some("3"));
233+
let text = extract_field(raw.as_bytes(), 58).expect("expected Text(58) field");
234+
assert!(
235+
text.contains("BAD_COMP"),
236+
"reject text should mention the bad comp ID, got: {text}"
237+
);
238+
}
239+
_ => panic!("expected SendMessage(Reject), got {msg:?}"),
240+
}
241+
242+
// Second message: Logout (35=5)
243+
let msg = rx.recv().await.unwrap();
244+
match &msg {
245+
WriterMessage::SendMessage(raw) => {
246+
assert_eq!(extract_msg_type(raw.as_bytes()).as_deref(), Some("5"));
247+
}
248+
_ => panic!("expected SendMessage(Logout), got {msg:?}"),
249+
}
250+
251+
// Third: Disconnect
252+
assert!(matches!(
253+
rx.recv().await.unwrap(),
254+
WriterMessage::Disconnect
255+
));
256+
257+
// Sender seq incremented twice (reject + logout)
258+
assert_eq!(ctx.store.next_sender_seq, 3);
259+
}
260+
261+
#[tokio::test]
262+
async fn handle_sequence_number_too_low_possible_duplicate_returns_stay() {
263+
let mut ctx = create_test_ctx(FakeMessageStore::new());
264+
let (writer, mut rx) = create_writer();
265+
266+
let result = handle_sequence_number_too_low(&mut ctx, &writer, 5, 1, true).await;
267+
268+
assert!(matches!(result, TransitionResult::Stay));
269+
270+
// No messages should have been sent
271+
assert!(rx.try_recv().is_err());
272+
273+
// Store should be untouched
274+
assert_eq!(ctx.store.next_sender_seq, 1);
275+
assert_eq!(ctx.store.next_target_seq, 1);
276+
}
277+
278+
#[tokio::test]
279+
async fn handle_sequence_number_too_low_returns_transition_to_disconnected_without_reconnect() {
280+
let mut ctx = create_test_ctx(FakeMessageStore::new());
281+
let (writer, mut rx) = create_writer();
282+
283+
let result = handle_sequence_number_too_low(&mut ctx, &writer, 5, 1, false).await;
284+
285+
match result {
286+
TransitionResult::TransitionTo(state) => {
287+
assert!(!state.should_reconnect());
288+
}
289+
TransitionResult::Stay => panic!("expected TransitionTo(Disconnected)"),
290+
}
291+
292+
// Should send a Logout mentioning the sequence mismatch, then disconnect
293+
let msg = rx.recv().await.unwrap();
294+
match &msg {
295+
WriterMessage::SendMessage(raw) => {
296+
assert_eq!(extract_msg_type(raw.as_bytes()).as_deref(), Some("5"));
297+
let text = extract_field(raw.as_bytes(), 58).expect("expected Text(58) field");
298+
assert!(
299+
text.contains("5") && text.contains("1"),
300+
"logout text should mention expected/actual seq nums, got: {text}"
301+
);
302+
}
303+
_ => panic!("expected SendMessage(Logout), got {msg:?}"),
304+
}
305+
assert!(matches!(
306+
rx.recv().await.unwrap(),
307+
WriterMessage::Disconnect
308+
));
309+
310+
assert_eq!(ctx.store.next_sender_seq, 2);
311+
}
312+
313+
#[tokio::test]
314+
async fn handle_sending_time_accuracy_problem_sends_reject() {
315+
let mut ctx = create_test_ctx(FakeMessageStore::new());
316+
let (writer, mut rx) = create_writer();
317+
318+
handle_sending_time_accuracy_problem(&mut ctx, &writer, 42, "bad time").await;
319+
320+
let msg = rx.recv().await.unwrap();
321+
match &msg {
322+
WriterMessage::SendMessage(raw) => {
323+
assert_eq!(extract_msg_type(raw.as_bytes()).as_deref(), Some("3"));
324+
let text = extract_field(raw.as_bytes(), 58).expect("expected Text(58) field");
325+
assert!(
326+
text.contains("bad time"),
327+
"reject text should contain the provided text, got: {text}"
328+
);
329+
}
330+
_ => panic!("expected SendMessage(Reject), got {msg:?}"),
331+
}
332+
333+
// Target seq number should have been incremented
334+
assert_eq!(ctx.store.next_target_seq, 2);
335+
// Sender seq number should have been incremented for the outbound reject
336+
assert_eq!(ctx.store.next_sender_seq, 2);
337+
}
338+
339+
#[tokio::test]
340+
async fn handle_original_sending_time_missing_sends_reject() {
341+
let mut ctx = create_test_ctx(FakeMessageStore::new());
342+
let (writer, mut rx) = create_writer();
343+
344+
handle_original_sending_time_missing(&mut ctx, &writer, 7).await;
345+
346+
let msg = rx.recv().await.unwrap();
347+
match &msg {
348+
WriterMessage::SendMessage(raw) => {
349+
assert_eq!(extract_msg_type(raw.as_bytes()).as_deref(), Some("3"));
350+
let text = extract_field(raw.as_bytes(), 58).expect("expected Text(58) field");
351+
assert!(
352+
text.contains("original sending time"),
353+
"reject text should mention original sending time, got: {text}"
354+
);
355+
}
356+
_ => panic!("expected SendMessage(Reject), got {msg:?}"),
357+
}
358+
359+
// Both sender and target seq numbers should have been incremented
360+
assert_eq!(ctx.store.next_sender_seq, 2);
361+
assert_eq!(ctx.store.next_target_seq, 2);
362+
}
363+
364+
#[tokio::test]
365+
async fn handle_invalid_msg_type_sends_reject_for_message_with_seq_num() {
366+
let mut ctx = create_test_ctx(FakeMessageStore::new());
367+
let (writer, mut rx) = create_writer();
368+
369+
let mut message = Message::new("FIX.4.4", "ZZ");
370+
message.header_mut().set(MSG_SEQ_NUM, 1u64);
371+
372+
handle_invalid_msg_type(&mut ctx, &writer, &message, "ZZ").await;
373+
374+
let msg = rx.recv().await.unwrap();
375+
match &msg {
376+
WriterMessage::SendMessage(raw) => {
377+
assert_eq!(extract_msg_type(raw.as_bytes()).as_deref(), Some("3"));
378+
let text = extract_field(raw.as_bytes(), 58).expect("expected Text(58) field");
379+
assert!(
380+
text.contains("ZZ"),
381+
"reject text should mention the invalid msg type, got: {text}"
382+
);
383+
}
384+
_ => panic!("expected SendMessage(Reject), got {msg:?}"),
385+
}
386+
387+
// Sender seq incremented for the reject, target seq incremented because msg seq matched
388+
assert_eq!(ctx.store.next_sender_seq, 2);
389+
assert_eq!(ctx.store.next_target_seq, 2);
390+
}
391+
}

crates/hotfix/src/session/outbound.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,8 @@ mod tests {
136136

137137
#[tokio::test]
138138
async fn resend_messages_returns_error_for_garbled_stored_message() {
139-
let store = GarbledMessageStore {
140-
messages: vec![b"not a valid FIX message".to_vec()],
141-
};
139+
let mut store = FakeMessageStore::new();
140+
store.messages = vec![b"not a valid FIX message".to_vec()];
142141
let mut ctx = create_test_ctx(store);
143142
let (sender, _receiver) = mpsc::channel(10);
144143
let writer = WriterRef::new(sender);

crates/hotfix/src/session/test_utils.rs

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,74 @@
11
use crate::config::SessionConfig;
22
use crate::session::ctx::SessionCtx;
33
use crate::store::{MessageStore, Result as StoreResult};
4+
use crate::transport::writer::{WriterMessage, WriterRef};
45
use chrono::{DateTime, Utc};
56
use hotfix_message::MessageBuilder;
67
use hotfix_message::dict::Dictionary;
78
use hotfix_message::message::Config as MessageConfig;
9+
use tokio::sync::mpsc;
810

911
#[derive(Clone)]
10-
pub(crate) struct GarbledMessageStore {
12+
pub(crate) struct FakeMessageStore {
1113
pub(crate) messages: Vec<Vec<u8>>,
14+
pub(crate) next_sender_seq: u64,
15+
pub(crate) next_target_seq: u64,
16+
}
17+
18+
impl FakeMessageStore {
19+
pub(crate) fn new() -> Self {
20+
Self {
21+
messages: vec![],
22+
next_sender_seq: 1,
23+
next_target_seq: 1,
24+
}
25+
}
1226
}
1327

1428
#[async_trait::async_trait]
15-
impl MessageStore for GarbledMessageStore {
16-
async fn add(&mut self, _: u64, _: &[u8]) -> StoreResult<()> {
29+
impl MessageStore for FakeMessageStore {
30+
async fn add(&mut self, _: u64, msg: &[u8]) -> StoreResult<()> {
31+
self.messages.push(msg.to_vec());
1732
Ok(())
1833
}
1934
async fn get_slice(&self, _: usize, _: usize) -> StoreResult<Vec<Vec<u8>>> {
2035
Ok(self.messages.clone())
2136
}
2237
fn next_sender_seq_number(&self) -> u64 {
23-
1
38+
self.next_sender_seq
2439
}
2540
fn next_target_seq_number(&self) -> u64 {
26-
1
41+
self.next_target_seq
2742
}
2843
async fn increment_sender_seq_number(&mut self) -> StoreResult<()> {
44+
self.next_sender_seq += 1;
2945
Ok(())
3046
}
3147
async fn increment_target_seq_number(&mut self) -> StoreResult<()> {
48+
self.next_target_seq += 1;
3249
Ok(())
3350
}
34-
async fn set_target_seq_number(&mut self, _: u64) -> StoreResult<()> {
51+
async fn set_target_seq_number(&mut self, seq: u64) -> StoreResult<()> {
52+
self.next_target_seq = seq;
3553
Ok(())
3654
}
3755
async fn reset(&mut self) -> StoreResult<()> {
56+
self.messages.clear();
57+
self.next_sender_seq = 1;
58+
self.next_target_seq = 1;
3859
Ok(())
3960
}
4061
fn creation_time(&self) -> DateTime<Utc> {
4162
Utc::now()
4263
}
4364
}
4465

45-
pub(crate) fn create_test_ctx(store: GarbledMessageStore) -> SessionCtx<(), GarbledMessageStore> {
66+
pub(crate) fn create_writer() -> (WriterRef, mpsc::Receiver<WriterMessage>) {
67+
let (sender, receiver) = mpsc::channel(16);
68+
(WriterRef::new(sender), receiver)
69+
}
70+
71+
pub(crate) fn create_test_ctx(store: FakeMessageStore) -> SessionCtx<(), FakeMessageStore> {
4672
let message_config = MessageConfig::default();
4773
let dictionary = Dictionary::fix44();
4874
let message_builder = MessageBuilder::new(dictionary, message_config).unwrap();
@@ -68,3 +94,26 @@ pub(crate) fn create_test_ctx(store: GarbledMessageStore) -> SessionCtx<(), Garb
6894
message_config,
6995
}
7096
}
97+
98+
/// Extract the FIX message type (tag 35) from a raw FIX message bytes.
99+
pub(crate) fn extract_msg_type(raw: &[u8]) -> Option<String> {
100+
let s = std::str::from_utf8(raw).ok()?;
101+
for field in s.split('\x01') {
102+
if let Some(value) = field.strip_prefix("35=") {
103+
return Some(value.to_string());
104+
}
105+
}
106+
None
107+
}
108+
109+
/// Extract a string field value by tag number from raw FIX message bytes.
110+
pub(crate) fn extract_field(raw: &[u8], tag: u32) -> Option<String> {
111+
let s = std::str::from_utf8(raw).ok()?;
112+
let prefix = format!("{tag}=");
113+
for field in s.split('\x01') {
114+
if let Some(value) = field.strip_prefix(&prefix) {
115+
return Some(value.to_string());
116+
}
117+
}
118+
None
119+
}

0 commit comments

Comments
 (0)