diff --git a/peri-middlewares/src/mcp/callback_server_test.rs b/peri-middlewares/src/mcp/callback_server_test.rs index eb31d980..cb1976ec 100644 --- a/peri-middlewares/src/mcp/callback_server_test.rs +++ b/peri-middlewares/src/mcp/callback_server_test.rs @@ -75,21 +75,42 @@ async fn test_bind_multiple_servers() { drop(s2); } +/// #16:set_state 应该更新内部 state_param, +/// 使 wait_for_code 通过 parse_callback_url 严格校验 state 一致。 #[tokio::test] -async fn test_set_state_enables_callback_validation() { +async fn test_set_state_updates_validation() { let (mut server, _uri) = OAuthCallbackServer::bind().await.unwrap(); - // 默认 state_param 为空,set_state 之前 parse_callback_url 因 expected_state - // 为空会跳过校验。注入真实 state 后,state 不匹配应被拒绝。 - server.set_state("expected-csrf".to_string()); - assert_eq!(server.state_param, "expected-csrf"); + // 默认 state_param 为空字符串 + assert_eq!(server.state_param, ""); + server.set_state("csrf-token-123".to_string()); + assert_eq!(server.state_param, "csrf-token-123"); // 空 state 不应被注入(防止误把默认值覆盖成空) server.set_state(String::new()); - assert_eq!(server.state_param, "expected-csrf"); + assert_eq!(server.state_param, "csrf-token-123"); +} - let mismatch = parse_callback_url("/callback?code=c&state=other", &server.state_param); - assert!(mismatch.is_err(), "state 不匹配应被拒绝"); +/// #16:state 校验函数在 expected_state 非空时应严格匹配, +/// 不一致时返回 ParseFailed 错误(CSRF 防御)。 +#[test] +fn test_parse_callback_url_strict_state_validation_when_nonempty() { + // 一致 → 通过 + let ok = parse_callback_url("/cb?code=c&state=abc", "abc"); + assert!(ok.is_ok(), "state 一致应通过: {:?}", ok); - let ok = parse_callback_url("/callback?code=c&state=expected-csrf", &server.state_param); - assert!(ok.is_ok(), "state 匹配应通过"); + // 不一致 → 拒绝 + let mismatch = parse_callback_url("/cb?code=c&state=xyz", "abc"); + assert!(mismatch.is_err(), "state 不一致应拒绝"); + let err_msg = mismatch.unwrap_err().to_string(); + assert!( + err_msg.contains("CSRF") || err_msg.contains("state"), + "错误信息应提及 CSRF/state: {err_msg}" + ); +} + +/// #16:expected_state 为空时仍跳过校验(向后兼容老测试 / 未设置场景)。 +#[test] +fn test_parse_callback_url_skip_validation_when_expected_empty() { + let result = parse_callback_url("/cb?code=c&state=anything", ""); + assert!(result.is_ok(), "expected_state 为空时不应强制校验"); } diff --git a/peri-middlewares/src/mcp/oauth_flow_test.rs b/peri-middlewares/src/mcp/oauth_flow_test.rs index 1361107b..b2170f7e 100644 --- a/peri-middlewares/src/mcp/oauth_flow_test.rs +++ b/peri-middlewares/src/mcp/oauth_flow_test.rs @@ -80,3 +80,5 @@ }); assert_eq!(counter.load(Ordering::SeqCst), 1); } + + diff --git a/peri-tui/src/app/history_persistence.rs b/peri-tui/src/app/history_persistence.rs index 9de9c0d0..51823f98 100644 --- a/peri-tui/src/app/history_persistence.rs +++ b/peri-tui/src/app/history_persistence.rs @@ -77,6 +77,11 @@ fn restrict_to_owner_unix(_path: &std::path::Path) {} #[cfg(test)] mod tests { + // 注意:这里不放 `use super::*;`。 + // Windows 下 restrict_to_owner_unix 是 #[cfg(not(unix))] 的空实现, + // 整个 mod tests 在 Windows 下没有任何使用 super::* 内容的代码, + // clippy -D unused-imports 会把 super::* 当作未使用 import 报错。 + #[cfg(unix)] #[test] fn test_restrict_to_owner_unix_file_gets_0600() {