Skip to content

Commit 53704b1

Browse files
authored
feat: handle messages with incorrect BeginString and comp ID (#181)
* Add test case for handling message with invalid begin string * Fix comp ID checks and add test for incorrect target comp ID handling * Add test case for handling message with invalid sender comp ID
1 parent fb20fb1 commit 53704b1

4 files changed

Lines changed: 181 additions & 26 deletions

File tree

crates/hotfix/src/error.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,18 @@ pub enum MessageVerificationError {
1616

1717
/// The comp ID is different from our expectations.
1818
#[allow(dead_code)]
19-
#[error("incorrect comp id {0}")]
20-
IncorrectCompId(String),
19+
#[error("incorrect comp id {comp_id} ({comp_id_type:?})")]
20+
IncorrectCompId {
21+
comp_id: String,
22+
comp_id_type: CompIdType,
23+
msg_seq_num: u64,
24+
},
25+
}
26+
27+
#[derive(Debug)]
28+
pub enum CompIdType {
29+
Sender,
30+
Target,
2131
}
2232

2333
#[derive(Debug, Error)]

crates/hotfix/src/session.rs

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::message::parser::RawFixMessage;
2525
use crate::store::MessageStore;
2626
use crate::transport::writer::WriterRef;
2727

28-
use crate::error::MessageVerificationError;
28+
use crate::error::{CompIdType, MessageVerificationError};
2929
use crate::message::logout::Logout;
3030
use crate::message::resend_request::ResendRequest;
3131
use crate::message::sequence_reset::SequenceReset;
@@ -318,15 +318,15 @@ impl<M: FixMessage, S: MessageStore> Session<M, S> {
318318
&self,
319319
message: &Message,
320320
) -> std::result::Result<(), MessageVerificationError> {
321-
let begin_string: &str = message.header().get(fix44::BEGIN_STRING).unwrap();
321+
let begin_string: &str = message.header().get(fix44::BEGIN_STRING).unwrap_or("");
322322
if begin_string != self.config.begin_string.as_str() {
323323
return Err(MessageVerificationError::IncorrectBeginString(
324324
begin_string.to_string(),
325325
));
326326
}
327327

328328
let expected_seq_number = self.store.next_target_seq_number();
329-
let actual_seq_number: u64 = message.header().get(fix44::MSG_SEQ_NUM).unwrap();
329+
let actual_seq_number: u64 = message.header().get(fix44::MSG_SEQ_NUM).unwrap_or_default();
330330

331331
match actual_seq_number.cmp(&expected_seq_number) {
332332
Ordering::Greater => {
@@ -344,6 +344,28 @@ impl<M: FixMessage, S: MessageStore> Session<M, S> {
344344
_ => {}
345345
}
346346

347+
// our TargetCompId is always the same as the expected SenderCompId for them
348+
let expected_sender_comp_id: &str = self.config.target_comp_id.as_str();
349+
let actual_sender_comp_id: &str = message.header().get(fix44::SENDER_COMP_ID).unwrap_or("");
350+
if expected_sender_comp_id != actual_sender_comp_id {
351+
return Err(MessageVerificationError::IncorrectCompId {
352+
comp_id: actual_sender_comp_id.to_string(),
353+
comp_id_type: CompIdType::Sender,
354+
msg_seq_num: actual_seq_number,
355+
});
356+
}
357+
358+
// our SenderCompId is always the same as the expected TargetCompId for them
359+
let expected_target_comp_id: &str = self.config.sender_comp_id.as_str();
360+
let actual_target_comp_id: &str = message.header().get(fix44::TARGET_COMP_ID).unwrap_or("");
361+
if expected_target_comp_id != actual_target_comp_id {
362+
return Err(MessageVerificationError::IncorrectCompId {
363+
comp_id: actual_target_comp_id.to_string(),
364+
comp_id_type: CompIdType::Target,
365+
msg_seq_num: actual_seq_number,
366+
});
367+
}
368+
347369
Ok(())
348370
}
349371

@@ -508,21 +530,40 @@ impl<M: FixMessage, S: MessageStore> Session<M, S> {
508530
MessageVerificationError::IncorrectBeginString(begin_string) => {
509531
self.handle_incorrect_begin_string(begin_string).await;
510532
}
511-
MessageVerificationError::IncorrectCompId(comp_id) => {
512-
self.handle_incorrect_comp_id(comp_id).await;
533+
MessageVerificationError::IncorrectCompId {
534+
comp_id,
535+
comp_id_type,
536+
msg_seq_num,
537+
} => {
538+
self.handle_incorrect_comp_id(comp_id, comp_id_type, msg_seq_num)
539+
.await;
513540
}
514541
}
515542
}
516543

517544
async fn handle_incorrect_begin_string(&mut self, received_begin_string: String) {
518-
// TODO: this should be a disconnect (and maybe a reject first?)
519-
// see: https://www.fixtrading.org/standards/fix-session-layer-online/#when-to-terminate-a-fix-connection-by-terminating-the-transport-layer-connection-instead-of-sending-a-logout355
520-
panic!("incorrect begin string received: {received_begin_string}");
521-
}
545+
self.logout_and_terminate(&format!(
546+
"beginString={received_begin_string} is not supported"
547+
))
548+
.await;
549+
}
550+
551+
async fn handle_incorrect_comp_id(
552+
&mut self,
553+
received_comp_id: String,
554+
comp_id_type: CompIdType,
555+
msg_seq_num: u64,
556+
) {
557+
error!(
558+
"rejecting message with incorrect comp ID: {received_comp_id} (type: {comp_id_type:?})"
559+
);
560+
let reject = Reject::new(msg_seq_num)
561+
.session_reject_reason(SessionRejectReason::ValueIsIncorrect)
562+
.text(&format!("invalid comp ID {received_comp_id}"));
563+
self.send_message(reject).await;
522564

523-
async fn handle_incorrect_comp_id(&mut self, received_comp_id: String) {
524-
// TODO: this should also be a disconnect I think (and maybe a reject first?)
525-
panic!("incorrect comp ID received: {received_comp_id}");
565+
self.logout_and_terminate("incorrect comp ID received")
566+
.await;
526567
}
527568

528569
async fn handle_sequence_number_too_low(&mut self, actual: u64, expected: u64) {

crates/hotfix/tests/common/setup.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ use hotfix_message::fix44::MSG_TYPE;
1313
pub const HEARTBEAT_INTERVAL: u64 = 30;
1414
pub const LOGON_TIMEOUT: u64 = 10;
1515

16+
pub const COUNTERPARTY_COMP_ID: &str = "dummy-acceptor";
17+
pub const OUR_COMP_ID: &str = "dummy-initiator";
18+
1619
pub async fn given_a_connected_session() -> (SessionRef<TestMessage>, MockCounterparty<TestMessage>)
1720
{
1821
let message_store = InMemoryMessageStore::default();
@@ -48,8 +51,8 @@ pub async fn given_an_active_session() -> (SessionRef<TestMessage>, MockCounterp
4851
pub fn create_session_config() -> SessionConfig {
4952
SessionConfig {
5053
begin_string: "FIX.4.4".to_string(),
51-
sender_comp_id: "dummy-initiator".to_string(),
52-
target_comp_id: "dummy-acceptor".to_string(),
54+
sender_comp_id: OUR_COMP_ID.to_string(),
55+
target_comp_id: COUNTERPARTY_COMP_ID.to_string(),
5356
data_dictionary_path: None,
5457
connection_host: "".to_string(),
5558
connection_port: 0,

crates/hotfix/tests/session_test_cases/invalid_message_tests.rs

Lines changed: 111 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use crate::common::actions::when;
22
use crate::common::assertions::then;
3-
use crate::common::setup::given_an_active_session;
3+
use crate::common::setup::{COUNTERPARTY_COMP_ID, OUR_COMP_ID, given_an_active_session};
44
use crate::common::test_messages::{TestMessage, replace_field_value};
55
use hotfix::message::{FixMessage, generate_message};
66
use hotfix::session::Status;
77
use hotfix_message::dict::{FieldLocation, FixDatatype};
8+
use hotfix_message::field_types::Timestamp;
89
use hotfix_message::fix44::MSG_TYPE;
9-
use hotfix_message::message::Message;
10+
use hotfix_message::message::{Config, Message};
1011
use hotfix_message::{HardCodedFixFieldDefinition, Part, fix44};
1112

1213
/// Tests that when a counterparty sends a message containing an invalid/unrecognised field,
@@ -34,8 +35,6 @@ async fn test_garbled_message_with_invalid_target_comp_id_gets_ignored() {
3435

3536
// counterparty sends a message with invalid body length, which constitutes a garbled message
3637
let garbled_message = build_execution_report_with_incorrect_body_length(
37-
"dummy-acceptor",
38-
"dummy-initiator",
3938
mock_counterparty.next_target_sequence_number(),
4039
);
4140
when(&mut mock_counterparty)
@@ -59,6 +58,79 @@ async fn test_garbled_message_with_invalid_target_comp_id_gets_ignored() {
5958
then(&mut mock_counterparty).gets_disconnected().await;
6059
}
6160

61+
/// Tests that when a counterparty sends a message with an invalid BeginString,
62+
/// the session logs out and disconnects.
63+
#[tokio::test]
64+
async fn test_message_with_invalid_begin_string() {
65+
let (_session, mut mock_counterparty) = given_an_active_session().await;
66+
67+
// a message with invalid BeginString is sent by the counterparty
68+
let invalid_message = build_execution_report_with_incorrect_begin_string(
69+
mock_counterparty.next_target_sequence_number(),
70+
);
71+
when(&mut mock_counterparty)
72+
.sends_raw_message(invalid_message)
73+
.await;
74+
75+
// then we log out and disconnect
76+
then(&mut mock_counterparty)
77+
.receives(|msg| assert_eq!(msg.header().get::<&str>(MSG_TYPE).unwrap(), "5"))
78+
.await;
79+
then(&mut mock_counterparty).gets_disconnected().await;
80+
}
81+
82+
/// Tests that when a counterparty sends a message with an invalid TargetCompId,
83+
/// the session sends a Reject (MsgType=3) and logs out and disconnects.
84+
#[tokio::test]
85+
async fn test_message_with_invalid_target_comp_id() {
86+
let (_session, mut mock_counterparty) = given_an_active_session().await;
87+
88+
// a message with incorrect TargetCompId is sent by the counterparty
89+
let invalid_message = build_execution_report_with_comp_id(
90+
mock_counterparty.next_target_sequence_number(),
91+
COUNTERPARTY_COMP_ID,
92+
"WRONG_COMP_ID",
93+
);
94+
when(&mut mock_counterparty)
95+
.sends_raw_message(invalid_message)
96+
.await;
97+
98+
// then we send a reject, log out and disconnect
99+
then(&mut mock_counterparty)
100+
.receives(|msg| assert_eq!(msg.header().get::<&str>(MSG_TYPE).unwrap(), "3"))
101+
.await;
102+
then(&mut mock_counterparty)
103+
.receives(|msg| assert_eq!(msg.header().get::<&str>(MSG_TYPE).unwrap(), "5"))
104+
.await;
105+
then(&mut mock_counterparty).gets_disconnected().await;
106+
}
107+
108+
/// Tests that when a counterparty sends a message with an invalid SenderCompId,
109+
/// the session sends a Reject (MsgType=3) and logs out and disconnects.
110+
#[tokio::test]
111+
async fn test_message_with_invalid_sender_comp_id() {
112+
let (_session, mut mock_counterparty) = given_an_active_session().await;
113+
114+
// a message with incorrect SenderCompId is sent by the counterparty
115+
let invalid_message = build_execution_report_with_comp_id(
116+
mock_counterparty.next_target_sequence_number(),
117+
"WRONG_COMP_ID",
118+
OUR_COMP_ID,
119+
);
120+
when(&mut mock_counterparty)
121+
.sends_raw_message(invalid_message)
122+
.await;
123+
124+
// then we send a reject, log out and disconnect
125+
then(&mut mock_counterparty)
126+
.receives(|msg| assert_eq!(msg.header().get::<&str>(MSG_TYPE).unwrap(), "3"))
127+
.await;
128+
then(&mut mock_counterparty)
129+
.receives(|msg| assert_eq!(msg.header().get::<&str>(MSG_TYPE).unwrap(), "5"))
130+
.await;
131+
then(&mut mock_counterparty).gets_disconnected().await;
132+
}
133+
62134
/// A new order message with an extra, invalid field.
63135
#[derive(Clone, Debug)]
64136
struct ExecutionReportWithInvalidField {
@@ -121,16 +193,45 @@ pub const CUSTOM_FIELD: &HardCodedFixFieldDefinition = &HardCodedFixFieldDefinit
121193
location: FieldLocation::Body,
122194
};
123195

124-
fn build_execution_report_with_incorrect_body_length(
125-
sender_comp_id: &str,
126-
target_comp_id: &str,
127-
msg_seq_num: usize,
128-
) -> Vec<u8> {
196+
fn build_execution_report_with_incorrect_body_length(msg_seq_num: usize) -> Vec<u8> {
129197
let report = TestMessage::dummy_execution_report();
130198
let mut raw_message =
131-
generate_message(sender_comp_id, target_comp_id, msg_seq_num, report).unwrap();
199+
generate_message(COUNTERPARTY_COMP_ID, OUR_COMP_ID, msg_seq_num, report).unwrap();
132200

133201
replace_field_value(&mut raw_message, 9, b"999");
134202

135203
raw_message
136204
}
205+
206+
fn build_execution_report_with_incorrect_begin_string(msg_seq_num: usize) -> Vec<u8> {
207+
let report = TestMessage::dummy_execution_report();
208+
209+
// we expect BeginString FIX.4.4 but this message contains FIX.4.2
210+
let mut msg = Message::new("FIX.4.2", report.message_type());
211+
msg.set(fix44::SENDER_COMP_ID, COUNTERPARTY_COMP_ID);
212+
msg.set(fix44::TARGET_COMP_ID, OUR_COMP_ID);
213+
msg.set(fix44::MSG_SEQ_NUM, msg_seq_num);
214+
msg.set(fix44::SENDING_TIME, Timestamp::utc_now());
215+
216+
report.write(&mut msg);
217+
218+
msg.encode(&Config::default()).unwrap()
219+
}
220+
221+
fn build_execution_report_with_comp_id(
222+
msg_seq_num: usize,
223+
sender_comp_id: &str,
224+
target_comp_id: &str,
225+
) -> Vec<u8> {
226+
let report = TestMessage::dummy_execution_report();
227+
228+
let mut msg = Message::new("FIX.4.4", report.message_type());
229+
msg.set(fix44::SENDER_COMP_ID, sender_comp_id);
230+
msg.set(fix44::TARGET_COMP_ID, target_comp_id);
231+
msg.set(fix44::MSG_SEQ_NUM, msg_seq_num);
232+
msg.set(fix44::SENDING_TIME, Timestamp::utc_now());
233+
234+
report.write(&mut msg);
235+
236+
msg.encode(&Config::default()).unwrap()
237+
}

0 commit comments

Comments
 (0)