diff --git a/.gitignore b/.gitignore index 214d93d..2491851 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,4 @@ backups/ # Test files (local only) backend/test2/ +backend/ARCHITECTURE.md diff --git a/backend/WS_PHASE_PROTOCOL.md b/backend/WS_PHASE_PROTOCOL.md new file mode 100644 index 0000000..63ad15a --- /dev/null +++ b/backend/WS_PHASE_PROTOCOL.md @@ -0,0 +1,107 @@ +# WebSocket 房间阶段协议(投票阶段信号) + +目标:让前端能实时展示“投票阶段开始/结束”的倒计时,并在投票阶段禁用提交画作入口。 + +## 1) 房间阶段(phase) + +后端通过 Socket 下发的 `phase` 字段取值: +- `active`:提交/观赏阶段(允许提交画作;不一定允许投票) +- `voting`:投票阶段(不允许提交画作;允许投票/撤票) +- `gameover`:结算结束 + +后端落库字段:`rooms.status`。 + +## 2) 事件:sync:state(进入房间时全量同步) + +事件名:`sync:state` + +触发:客户端 `room:join` 后,服务端会 emit 给该 socket。 + +Payload(camelCase,新增字段对旧前端向后兼容): +- `phase: string`(`active`/`voting`/`gameover`) +- `roomId: string` +- `totalItems: number` +- `aiCount: number` +- `turbidity: number` +- `theme: ThemeResponse` +- `items: GameItemData[]` +- `votingStartedAt?: number`(Unix 毫秒时间戳,仅 voting 期存在) +- `votingEndsAt?: number`(Unix 毫秒时间戳,仅 voting 期存在) +- `serverTime: number`(Unix 毫秒时间戳,服务端当前时间) + +示例(voting 中): +```json +{ + "phase": "voting", + "roomId": "ABCD12", + "totalItems": 8, + "aiCount": 2, + "turbidity": 0.4, + "votingStartedAt": 1769948000000, + "votingEndsAt": 1769948045000, + "serverTime": 1769948012345, + "theme": { "...": "..." }, + "items": [] +} +``` + +## 3) 事件:phase:update(房间内广播阶段切换) + +事件名:`phase:update` + +触发:服务端在以下时机会向 `within(roomId)` 广播给同房间所有连接: +- `active -> voting`:进入投票阶段 +- `voting -> active`:投票超时且未结束游戏,重置票数并退出投票阶段 +- `* -> gameover`:胜负判定落库 gameover 后立即广播 + +Payload(camelCase): +- `phase: string` +- `roomId: string` +- `votingStartedAt?: number`(Unix ms) +- `votingEndsAt?: number`(Unix ms) +- `serverTime: number`(Unix ms) + +示例(进入 voting): +```json +{ + "phase": "voting", + "roomId": "ABCD12", + "votingStartedAt": 1769948000000, + "votingEndsAt": 1769948045000, + "serverTime": 1769948000001 +} +``` + +示例(退出 voting 回到 active): +```json +{ + "phase": "active", + "roomId": "ABCD12", + "serverTime": 1769948050000 +} +``` + +示例(gameover): +```json +{ + "phase": "gameover", + "roomId": "ABCD12", + "serverTime": 1769948060000 +} +``` + +## 4) 前端倒计时推荐算法 + +服务端同时提供 `votingEndsAt` 与 `serverTime`,用于抵消客户端/服务器时钟偏差。 + +推荐做法: +1. 记录收到消息时的本地时间 `clientNowAtReceive = Date.now()` 与消息中的 `serverTime`。 +2. 估算服务器与客户端的时间偏移:`offset = serverTime - clientNowAtReceive`。 +3. 后续倒计时使用:`remainingMs = votingEndsAt - (Date.now() + offset)`。 + +若前端不做校时,也可直接:`remainingMs = votingEndsAt - Date.now()`(误差会更大)。 + +## 5) 禁提交与兼容性说明 + +- 后端在 `voting` 阶段会拒绝 `POST /api/rooms/:room_code/drawings`(HTTP 400),因此旧前端即便未禁用按钮,也不会提交成功。 +- 新增字段/新事件均为增量:旧前端忽略未知字段/事件即可继续运行。 diff --git a/backend/src/config.rs b/backend/src/config.rs index cd2b74f..c25f51e 100644 --- a/backend/src/config.rs +++ b/backend/src/config.rs @@ -21,6 +21,14 @@ pub struct Config { pub wechat_mp_secret: Option, pub auth_token_ttl_days: i64, pub dev_auth_enabled: bool, + pub vote_threshold_ratio: f64, + pub vote_min_threshold: i32, + pub human_eliminated_ratio: f64, + pub victory_human_survive_ratio: f64, + pub ai_overflow_delta: i64, + pub min_humans_to_start_voting: i64, + pub voting_duration_seconds: i64, + pub submit_duration_seconds: i64, } impl Config { @@ -101,6 +109,38 @@ impl Config { dev_auth_enabled: std::env::var("DEV_AUTH_ENABLED") .map(|v| v.to_lowercase() == "true") .unwrap_or(false), + vote_threshold_ratio: std::env::var("VOTE_THRESHOLD_RATIO") + .unwrap_or_else(|_| "0.6".to_string()) + .parse() + .context("VOTE_THRESHOLD_RATIO must be a valid float")?, + vote_min_threshold: std::env::var("VOTE_MIN_THRESHOLD") + .unwrap_or_else(|_| "2".to_string()) + .parse() + .context("VOTE_MIN_THRESHOLD must be a valid number")?, + human_eliminated_ratio: std::env::var("HUMAN_ELIMINATED_RATIO") + .unwrap_or_else(|_| "0.4".to_string()) + .parse() + .context("HUMAN_ELIMINATED_RATIO must be a valid float")?, + victory_human_survive_ratio: std::env::var("VICTORY_HUMAN_SURVIVE_RATIO") + .unwrap_or_else(|_| "0.6".to_string()) + .parse() + .context("VICTORY_HUMAN_SURVIVE_RATIO must be a valid float")?, + ai_overflow_delta: std::env::var("AI_OVERFLOW_DELTA") + .unwrap_or_else(|_| "2".to_string()) + .parse() + .context("AI_OVERFLOW_DELTA must be a valid number")?, + min_humans_to_start_voting: std::env::var("MIN_HUMANS_TO_START_VOTING") + .unwrap_or_else(|_| "2".to_string()) + .parse() + .context("MIN_HUMANS_TO_START_VOTING must be a valid number")?, + voting_duration_seconds: std::env::var("VOTING_DURATION_SECONDS") + .unwrap_or_else(|_| "45".to_string()) + .parse() + .context("VOTING_DURATION_SECONDS must be a valid number")?, + submit_duration_seconds: std::env::var("SUBMIT_DURATION_SECONDS") + .unwrap_or_else(|_| "60".to_string()) + .parse() + .context("SUBMIT_DURATION_SECONDS must be a valid number")?, }) } } diff --git a/backend/src/main.rs b/backend/src/main.rs index 39f385c..525f54f 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -3,6 +3,7 @@ use axum::{ routing::{get, post}, Extension, Router, }; +use socketioxide::handler::ConnectHandler; use socketioxide::SocketIo; use sqlx::postgres::PgPoolOptions; use std::{net::SocketAddr, sync::Arc}; @@ -68,7 +69,10 @@ async fn main() -> Result<()> { let (sio_layer, io) = SocketIo::builder().with_state(state.clone()).build_layer(); // 注册 Socket.IO 事件处理器 - io.ns("/", ws::socketio_handler::on_connect); + io.ns( + "/", + ws::socketio_handler::on_connect.with(ws::socketio_handler::auth_middleware), + ); // CORS 配置 let cors = CorsLayer::new() @@ -128,6 +132,7 @@ fn api_routes(state: Arc, io: SocketIo) -> Router { post(routes::drawings::report_drawing), ) .route("/auth/wechat_mp/login", post(routes::auth::wechat_mp_login)) + .route("/auth/guest/login", post(routes::auth::guest_login)) .route("/auth/dev/login", post(routes::dev_auth::dev_login)) .route("/auth/me", get(routes::auth::me)) .route("/auth/logout", post(routes::auth::logout)) diff --git a/backend/src/models/room.rs b/backend/src/models/room.rs index e743f7a..5f468fb 100644 --- a/backend/src/models/room.rs +++ b/backend/src/models/room.rs @@ -3,6 +3,8 @@ use serde::{Deserialize, Serialize}; use sqlx::FromRow; use uuid::Uuid; +use crate::config::Config; + #[derive(Debug, Clone, Serialize, Deserialize, FromRow)] pub struct Room { pub id: Uuid, @@ -20,10 +22,11 @@ pub struct Room { } impl Room { - /// 计算动态投票阈值 (在线人数的 30%) - pub fn vote_threshold(&self) -> i32 { - let dynamic = ((self.online_count as f64) * 0.3).ceil() as i32; - std::cmp::max(4, dynamic) + /// 计算动态投票阈值(在线人数 * 配置比例,且不低于最小阈值) + pub fn vote_threshold(&self, config: &Config) -> i32 { + let ratio = config.vote_threshold_ratio.clamp(0.0, 1.0); + let dynamic = ((self.online_count as f64) * ratio).ceil() as i32; + std::cmp::max(config.vote_min_threshold, dynamic) } } diff --git a/backend/src/routes/auth.rs b/backend/src/routes/auth.rs index 4b69af8..ec0bb9e 100644 --- a/backend/src/routes/auth.rs +++ b/backend/src/routes/auth.rs @@ -24,6 +24,28 @@ pub struct WechatMpLoginResponse { pub is_new_user: bool, } +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GuestLoginRequest { + #[serde(default, alias = "deviceToken")] + pub device_token: Option, + // Deprecated compatibility field; ignored for identity binding. + #[serde(default, alias = "device_id")] + pub device_id: Option, + #[serde(default, alias = "session_id", alias = "legacySessionId")] + pub session_id: Option, +} + +#[derive(Debug, serde::Serialize)] +#[serde(rename_all = "camelCase")] +pub struct GuestLoginResponse { + pub token: String, + pub user_id: String, + pub is_new_user: bool, + pub is_guest: bool, + pub device_token: String, +} + pub async fn wechat_mp_login( State(state): State>, Json(req): Json, @@ -43,6 +65,28 @@ pub async fn wechat_mp_login( })) } +pub async fn guest_login( + State(state): State>, + Json(req): Json, +) -> Result, ApiError> { + let requested_device_token = req.device_token.as_deref().or(req.device_id.as_deref()); + let result = auth::login_guest_device( + &state.db, + &state.config, + requested_device_token, + req.session_id.as_deref(), + ) + .await?; + + Ok(Json(GuestLoginResponse { + token: result.token, + user_id: result.user_id.to_string(), + is_new_user: result.is_new_user, + is_guest: true, + device_token: result.device_token, + })) +} + pub async fn me( State(state): State>, TypedHeader(auth_header): TypedHeader>, diff --git a/backend/src/routes/drawings.rs b/backend/src/routes/drawings.rs index b7ce3d7..19a0218 100644 --- a/backend/src/routes/drawings.rs +++ b/backend/src/routes/drawings.rs @@ -4,6 +4,7 @@ use axum::{ response::{IntoResponse, Response}, Extension, Json, }; +use chrono::Utc; use rand::{rngs::StdRng, Rng, SeedableRng}; use socketioxide::SocketIo; use std::sync::Arc; @@ -33,6 +34,11 @@ pub async fn create_drawing( if room.status != "active" { return Err(ApiError::BadRequest("Room is not active".to_string())); } + if let Some(voting_ends_at) = room.voting_ends_at { + if Utc::now() < voting_ends_at { + return Err(ApiError::BadRequest("Voting is in progress".to_string())); + } + } // 获取主题配置 let theme: Theme = sqlx::query_as("SELECT * FROM themes WHERE id = $1") @@ -250,7 +256,7 @@ pub async fn vote_drawing( .fetch_one(&state.db) .await?; - let threshold = room.vote_threshold(); + let threshold = room.vote_threshold(&state.config); // 注意: Socket.IO 广播由 socketio_handler 的 vote:cast 事件处理 diff --git a/backend/src/routes/themes.rs b/backend/src/routes/themes.rs index 04cb227..e42f72f 100644 --- a/backend/src/routes/themes.rs +++ b/backend/src/routes/themes.rs @@ -48,7 +48,7 @@ pub async fn get_or_create_room_by_theme( // 查找该主题已有的活跃房间 let existing_room: Option = sqlx::query_as( - "SELECT r.* FROM rooms r WHERE r.theme_id = $1 AND r.status = 'active' ORDER BY r.created_at DESC LIMIT 1" + "SELECT r.* FROM rooms r WHERE r.theme_id = $1 AND r.status IN ('active', 'voting') ORDER BY r.created_at DESC LIMIT 1" ) .bind(theme.id) .fetch_optional(&state.db) diff --git a/backend/src/services/auth.rs b/backend/src/services/auth.rs index f13a9ff..7adcf88 100644 --- a/backend/src/services/auth.rs +++ b/backend/src/services/auth.rs @@ -10,6 +10,7 @@ use crate::services::ApiError; pub const PROVIDER_WECHAT_MINIPROGRAM: &str = "wechat_miniprogram"; pub const PROVIDER_DEV: &str = "dev"; +pub const PROVIDER_GUEST_DEVICE: &str = "guest_device"; pub struct LoginResult { pub token: String, @@ -17,6 +18,13 @@ pub struct LoginResult { pub is_new_user: bool, } +pub struct GuestLoginResult { + pub token: String, + pub user_id: Uuid, + pub is_new_user: bool, + pub device_token: String, +} + pub struct WechatSession { pub openid: String, pub unionid: Option, @@ -36,6 +44,25 @@ fn generate_token() -> String { base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) } +fn is_valid_guest_device_token(token: &str) -> bool { + token.len() <= 120 + && token.starts_with("guest_device:") + && token + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == ':' || c == '-' || c == '_') +} + +fn normalize_guest_openid(device_token: Option<&str>) -> String { + if let Some(device_token) = device_token { + let trimmed = device_token.trim(); + if !trimmed.is_empty() && is_valid_guest_device_token(trimmed) { + return trimmed.to_string(); + } + } + + format!("guest_device:{}", Uuid::new_v4()) +} + async fn wechat_code_to_session(config: &Config, code: &str) -> Result { let Some(appid) = config.wechat_mp_appid.as_deref() else { return Err(ApiError::BadRequest( @@ -267,6 +294,79 @@ pub async fn login_dev( }) } +pub async fn login_guest_device( + db: &PgPool, + config: &Config, + device_token: Option<&str>, + legacy_session_id: Option<&str>, +) -> Result { + let appid = "standalone"; + let openid = normalize_guest_openid(device_token); + + let mut tx = db.begin().await?; + + let existing_user_id: Option = sqlx::query_scalar( + r#" + SELECT user_id + FROM auth_identities + WHERE provider = $1 AND appid = $2 AND openid = $3 + "#, + ) + .bind(PROVIDER_GUEST_DEVICE) + .bind(appid) + .bind(&openid) + .fetch_optional(&mut *tx) + .await?; + + let (user_id, is_new_user) = if let Some(user_id) = existing_user_id { + (user_id, false) + } else { + let user_id: Uuid = sqlx::query_scalar("INSERT INTO users DEFAULT VALUES RETURNING id") + .fetch_one(&mut *tx) + .await?; + + let _ = sqlx::query( + r#" + INSERT INTO auth_identities (user_id, provider, appid, openid, unionid) + VALUES ($1, $2, $3, $4, NULL) + "#, + ) + .bind(user_id) + .bind(PROVIDER_GUEST_DEVICE) + .bind(appid) + .bind(&openid) + .execute(&mut *tx) + .await?; + + (user_id, true) + }; + + let token = generate_token(); + let expires_at = Utc::now() + Duration::days(config.auth_token_ttl_days); + + let _ = sqlx::query( + r#" + INSERT INTO auth_sessions (token, user_id, legacy_session_id, expires_at) + VALUES ($1, $2, $3, $4) + "#, + ) + .bind(&token) + .bind(user_id) + .bind(legacy_session_id) + .bind(expires_at) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + + Ok(GuestLoginResult { + token, + user_id, + is_new_user, + device_token: openid, + }) +} + pub async fn user_id_from_token(db: &PgPool, token: &str) -> Result { let user_id: Option = sqlx::query_scalar( r#" @@ -318,3 +418,35 @@ pub async fn list_identities_for_user( .await?; Ok(rows) } + +#[cfg(test)] +mod tests { + use super::{is_valid_guest_device_token, normalize_guest_openid}; + + #[test] + fn normalize_guest_openid_accepts_existing_server_token() { + let openid = normalize_guest_openid(Some(" guest_device:abc-123 ")); + assert_eq!(openid, "guest_device:abc-123"); + } + + #[test] + fn normalize_guest_openid_rejects_plain_device_id_and_generates_server_token() { + let openid = normalize_guest_openid(Some("ios-device-raw-id")); + assert!(openid.starts_with("guest_device:")); + assert_ne!(openid, "ios-device-raw-id"); + } + + #[test] + fn normalize_guest_openid_falls_back_to_generated_server_token_for_empty_input() { + let openid = normalize_guest_openid(Some(" ")); + assert!(openid.starts_with("guest_device:")); + assert!(openid.len() > "guest_device:".len()); + } + + #[test] + fn validate_guest_device_token_format() { + assert!(is_valid_guest_device_token("guest_device:ok-123_abc")); + assert!(!is_valid_guest_device_token("device_123")); + assert!(!is_valid_guest_device_token("guest_device:contains space")); + } +} diff --git a/backend/src/ws/game_rules.rs b/backend/src/ws/game_rules.rs new file mode 100644 index 0000000..6fd5b73 --- /dev/null +++ b/backend/src/ws/game_rules.rs @@ -0,0 +1,48 @@ +pub fn clamp_ratio(r: f64) -> f64 { + r.clamp(0.0, 1.0) +} + +pub fn ceil_ratio(total: i64, ratio: f64) -> i64 { + ((total as f64) * clamp_ratio(ratio)).ceil() as i64 +} + +pub fn human_eliminated_limit(human_total: i64, ratio: f64) -> i64 { + std::cmp::max(1, ceil_ratio(human_total, ratio)) +} + +pub fn min_human_survive(human_total: i64, ratio: f64) -> i64 { + std::cmp::max(1, ceil_ratio(human_total, ratio)) +} + +pub fn ai_overflow(ai_alive: i64, human_alive: i64, delta: i64) -> bool { + let delta = delta.max(0); + ai_alive > human_alive + delta +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_human_eliminated_limit_small_room() { + assert_eq!(human_eliminated_limit(1, 0.4), 1); + assert_eq!(human_eliminated_limit(2, 0.4), 1); + assert_eq!(human_eliminated_limit(3, 0.4), 2); + assert_eq!(human_eliminated_limit(5, 0.4), 2); + } + + #[test] + fn test_min_human_survive_small_room() { + assert_eq!(min_human_survive(1, 0.6), 1); + assert_eq!(min_human_survive(2, 0.6), 2); + assert_eq!(min_human_survive(3, 0.6), 2); + assert_eq!(min_human_survive(5, 0.6), 3); + } + + #[test] + fn test_ai_overflow_delta() { + assert!(!ai_overflow(1, 1, 2)); + assert!(!ai_overflow(3, 1, 2)); + assert!(ai_overflow(4, 1, 2)); + } +} diff --git a/backend/src/ws/mod.rs b/backend/src/ws/mod.rs index 123e7ab..6e0852e 100644 --- a/backend/src/ws/mod.rs +++ b/backend/src/ws/mod.rs @@ -1,3 +1,4 @@ +pub mod game_rules; pub mod socketio_handler; pub use socketio_handler::GameItemData; diff --git a/backend/src/ws/socketio_handler.rs b/backend/src/ws/socketio_handler.rs index ee8f65b..dc161d7 100644 --- a/backend/src/ws/socketio_handler.rs +++ b/backend/src/ws/socketio_handler.rs @@ -1,12 +1,15 @@ //! Socket.IO 事件处理器 //! 兼容前端 socket.io-client +use chrono::{Duration, Utc}; use socketioxide::extract::{Data, SocketRef, State as SioState}; use std::sync::Arc; use tracing::{debug, info}; +use uuid::Uuid; use crate::models::{drawing_image_url, Drawing, DrawingItemRow, Room, Theme, ThemeResponse}; -use crate::services::AppState; +use crate::services::{auth, AppState}; +use crate::ws::game_rules; /// 存储在 socket extensions 中的会话信息 #[derive(Clone)] @@ -14,6 +17,39 @@ pub struct RoomSession { pub room_code: String, } +/// 存储在 socket extensions 中的鉴权信息 +#[derive(Clone)] +pub struct AuthSession { + pub user_id: Uuid, +} + +#[derive(Debug)] +pub struct WsAuthError(String); + +impl std::fmt::Display for WsAuthError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +impl std::error::Error for WsAuthError {} + +pub async fn auth_middleware( + socket: SocketRef, + state: SioState>, + Data(auth_data): Data, +) -> Result<(), WsAuthError> { + let token = extract_auth_token(&auth_data) + .ok_or_else(|| WsAuthError("missing auth token".to_string()))?; + + let user_id = auth::user_id_from_token(&state.db, &token) + .await + .map_err(|_| WsAuthError("invalid auth token".to_string()))?; + + socket.extensions.insert(AuthSession { user_id }); + Ok(()) +} + /// Socket.IO 连接处理 pub fn on_connect(socket: SocketRef, _state: SioState>) { let session_id = socket.id.to_string(); @@ -24,6 +60,7 @@ pub fn on_connect(socket: SocketRef, _state: SioState>) { socket.on("room:leave", on_room_leave); socket.on("vote:cast", on_vote_cast); socket.on("vote:retract", on_vote_retract); + // Kept for backward compatibility; currently disabled at handler level. socket.on("vote:chase", on_vote_chase); socket.on("comment:add", on_comment_add); socket.on_disconnect(on_disconnect); @@ -50,6 +87,8 @@ async fn on_room_join( // 更新在线人数 update_online_count(&state, room_id, 1).await; + let _ = phase_tick_by_room_code(&socket, &state, room_id).await; + // 发送房间初始状态 if let Ok(room_state) = get_room_state(&state, room_id).await { let _ = socket.emit("sync:state", &room_state); @@ -89,6 +128,14 @@ async fn on_vote_cast( state: SioState>, ) { info!("[Socket.IO] Vote cast: {:?}", data); + let Some(voter_id) = authenticated_voter_id(&socket) else { + let payload = serde_json::json!({ + "reason": "unauthorized", + "fishId": data.fish_id + }); + let _ = socket.emit("vote:error", &payload); + return; + }; let fish_id = match uuid::Uuid::parse_str(&data.fish_id) { Ok(id) => id, @@ -114,12 +161,25 @@ async fn on_vote_cast( Err(_) => return, }; + let room = match phase_tick_by_room_id(&socket, &state, room.id).await { + Some(r) => r, + None => return, + }; + if room.status != "voting" { + let payload = serde_json::json!({ + "reason": "not_voting", + "fishId": data.fish_id + }); + let _ = socket.emit("vote:error", &payload); + return; + } + // 插入投票记录 let insert_result = match sqlx::query( "INSERT INTO votes (drawing_id, session_id) VALUES ($1, $2) ON CONFLICT DO NOTHING", ) .bind(fish_id) - .bind(&data.voter_id) + .bind(&voter_id) .execute(&state.db) .await { @@ -163,13 +223,13 @@ async fn on_vote_cast( // 通知被投票者 vote:received let vote_received = serde_json::json!({ "fishId": data.fish_id, - "voterId": data.voter_id + "voterId": voter_id }); let _ = socket .within(room.room_code.clone()) .emit("vote:received", &vote_received); - let elimination_threshold = room.vote_threshold(); + let elimination_threshold = room.vote_threshold(&state.config); if new_count >= elimination_threshold { // 标记为淘汰 let _ = sqlx::query( @@ -203,7 +263,7 @@ async fn on_vote_cast( } // 检查游戏结束条件 - check_game_end(&socket, &state, &room).await; + check_game_end(&socket, &state, &room, true).await; } } @@ -214,17 +274,56 @@ async fn on_vote_retract( state: SioState>, ) { info!("[Socket.IO] Vote retract: {:?}", data); + let Some(voter_id) = authenticated_voter_id(&socket) else { + let payload = serde_json::json!({ + "reason": "unauthorized", + "fishId": data.fish_id + }); + let _ = socket.emit("vote:error", &payload); + return; + }; let fish_id = match uuid::Uuid::parse_str(&data.fish_id) { Ok(id) => id, Err(_) => return, }; + let drawing = match sqlx::query_as::<_, Drawing>("SELECT * FROM drawings WHERE id = $1") + .bind(fish_id) + .fetch_optional(&state.db) + .await + { + Ok(Some(d)) => d, + _ => return, + }; + + let room = match sqlx::query_as::<_, Room>("SELECT * FROM rooms WHERE id = $1") + .bind(drawing.room_id) + .fetch_one(&state.db) + .await + { + Ok(r) => r, + Err(_) => return, + }; + + let room = match phase_tick_by_room_id(&socket, &state, room.id).await { + Some(r) => r, + None => return, + }; + if room.status != "voting" { + let payload = serde_json::json!({ + "reason": "not_voting", + "fishId": data.fish_id + }); + let _ = socket.emit("vote:error", &payload); + return; + } + // 删除投票记录 let delete_result = match sqlx::query("DELETE FROM votes WHERE drawing_id = $1 AND session_id = $2") .bind(fish_id) - .bind(&data.voter_id) + .bind(&voter_id) .execute(&state.db) .await { @@ -247,104 +346,6 @@ async fn on_vote_retract( Err(_) => return, }; - // 获取房间 - if let Ok(drawing) = sqlx::query_as::<_, Drawing>("SELECT * FROM drawings WHERE id = $1") - .bind(fish_id) - .fetch_one(&state.db) - .await - { - if let Ok(room) = sqlx::query_as::<_, Room>("SELECT * FROM rooms WHERE id = $1") - .bind(drawing.room_id) - .fetch_one(&state.db) - .await - { - // 获取剩余投票者 - let voters: Vec = - sqlx::query_scalar("SELECT session_id FROM votes WHERE drawing_id = $1") - .bind(fish_id) - .fetch_all(&state.db) - .await - .unwrap_or_default(); - - // 广播更新 - let vote_update = serde_json::json!({ - "fishId": data.fish_id, - "count": new_count, - "voters": voters - }); - let _ = socket - .within(room.room_code) - .emit("vote:update", &vote_update); - } - } -} - -/// 追击 (重复投同目标,增加票数) -async fn on_vote_chase( - socket: SocketRef, - Data(data): Data, - state: SioState>, -) { - info!("[Socket.IO] Vote chase: {:?}", data); - - // 解析 fish_id (必须是 UUID 格式,非 UUID 静默忽略) - let fish_id = match uuid::Uuid::parse_str(&data.fish_id) { - Ok(id) => id, - Err(_) => { - info!("[Socket.IO] Vote chase: invalid UUID format, ignoring"); - return; - } - }; - - // 获取 drawing (必须存在且未被淘汰) - let drawing = match sqlx::query_as::<_, Drawing>("SELECT * FROM drawings WHERE id = $1") - .bind(fish_id) - .fetch_optional(&state.db) - .await - { - Ok(Some(d)) if !d.is_eliminated => d, - Ok(Some(_)) => { - info!("[Socket.IO] Vote chase: drawing already eliminated"); - return; - } - Ok(None) => { - info!("[Socket.IO] Vote chase: drawing not found"); - return; - } - Err(e) => { - info!("[Socket.IO] Vote chase: db error: {}", e); - return; - } - }; - - // 获取 room - let room = match sqlx::query_as::<_, Room>("SELECT * FROM rooms WHERE id = $1") - .bind(drawing.room_id) - .fetch_one(&state.db) - .await - { - Ok(r) => r, - Err(e) => { - info!("[Socket.IO] Vote chase: room not found: {}", e); - return; - } - }; - - // 增加票数 (不插入投票记录,与 vote:cast 的区别) - let new_count: i32 = match sqlx::query_scalar( - "UPDATE drawings SET vote_count = vote_count + 1, updated_at = NOW() WHERE id = $1 RETURNING vote_count" - ) - .bind(fish_id) - .fetch_one(&state.db) - .await { - Ok(c) => c, - Err(e) => { - info!("[Socket.IO] Vote chase: update failed: {}", e); - return; - } - }; - - // 获取投票者列表 let voters: Vec = sqlx::query_scalar("SELECT session_id FROM votes WHERE drawing_id = $1") .bind(fish_id) @@ -352,70 +353,41 @@ async fn on_vote_chase( .await .unwrap_or_default(); - // 广播 vote:update (不是 vote:chase) let vote_update = serde_json::json!({ "fishId": data.fish_id, "count": new_count, "voters": voters }); let _ = socket - .within(room.room_code.clone()) + .within(room.room_code) .emit("vote:update", &vote_update); +} - info!( - "[Socket.IO] Vote chase: updated count to {} for fish {}", - new_count, data.fish_id - ); - - let elimination_threshold = room.vote_threshold(); - if new_count >= elimination_threshold { - // 标记为淘汰 - let _ = sqlx::query( - "UPDATE drawings SET is_eliminated = TRUE, eliminated_at = NOW() WHERE id = $1", - ) - .bind(fish_id) - .execute(&state.db) - .await; - - // 广播 fish:eliminate - let eliminate_data = serde_json::json!({ - "fishId": data.fish_id, - "fishName": drawing.name, - "isAI": drawing.is_ai, - "fishOwnerId": drawing.session_id.unwrap_or_default(), - "killerNames": voters - }); - let _ = socket - .within(room.room_code.clone()) - .emit("fish:eliminate", &eliminate_data); - - info!( - "[Socket.IO] Vote chase: fish {} eliminated (isAI: {})", - data.fish_id, drawing.is_ai - ); - - // 更新 AI 计数 - if drawing.is_ai { - let _ = sqlx::query("UPDATE rooms SET ai_count = ai_count - 1 WHERE id = $1") - .bind(room.id) - .execute(&state.db) - .await; - } +/// 追击能力当前关闭:保留事件名以兼容旧前端,统一返回 `chase_disabled`。 +async fn on_vote_chase(socket: SocketRef, Data(data): Data) { + info!("[Socket.IO] Vote chase: {:?}", data); - // 检查游戏结束条件 - check_game_end(&socket, &state, &room).await; - } + let payload = serde_json::json!({ + "reason": "chase_disabled", + "fishId": data.fish_id + }); + let _ = socket.emit("vote:error", &payload); } /// 检查游戏结束条件 /// -/// 游戏结束条件: -/// 1. 前置检查: total_items <= 5 时不检查(还没有 AI 鱼出现) -/// 2. 失败条件: 杀了 3 条非 AI 鱼 -/// 3. 失败条件: AI 鱼数量 > 5 -/// 4. 胜利条件: AI 全灭 + 人类 >= 5 -async fn check_game_end(socket: &SocketRef, state: &AppState, room: &Room) { - // 重新查询最新的 room 数据以获取准确的 total_items +/// 游戏结束条件(基于当前存活/淘汰统计与配置比例): +/// 1. 失败:被淘汰人类达到 `human_eliminated_ratio` 推导阈值 +/// 2. 胜利:AI 全灭,且存活人类达到 `victory_human_survive_ratio` 推导阈值 +/// 3. 失败:AI 相对人类优势超过 `ai_overflow_delta` +/// +/// `broadcast_phase_update = true` 时,在写入 `rooms.status=gameover` 后立即广播 `phase:update`。 +async fn check_game_end( + socket: &SocketRef, + state: &AppState, + room: &Room, + broadcast_phase_update: bool, +) { let room = match sqlx::query_as::<_, Room>("SELECT * FROM rooms WHERE id = $1") .bind(room.id) .fetch_one(&state.db) @@ -425,23 +397,12 @@ async fn check_game_end(socket: &SocketRef, state: &AppState, room: &Room) { Err(_) => return, }; - // ============ 前置检查 ============ - // 如果总鱼数 <= 5,不允许游戏结束(还没有 AI 鱼) - const MIN_ITEMS_FOR_GAME_END: i32 = 6; // 5 条人类鱼 + 至少 1 条 AI 鱼 - if room.total_items < MIN_ITEMS_FOR_GAME_END { - tracing::info!( - "[Game] End check skipped: total_items={} < {}", - room.total_items, - MIN_ITEMS_FOR_GAME_END - ); - return; - } - // ============ 获取统计数据 ============ // 查询存活和淘汰的鱼数量 #[derive(sqlx::FromRow)] struct GameStats { ai_alive: i64, + ai_eliminated: i64, human_alive: i64, human_eliminated: i64, } @@ -449,6 +410,7 @@ async fn check_game_end(socket: &SocketRef, state: &AppState, room: &Room) { let stats: Option = sqlx::query_as( "SELECT COUNT(*) FILTER (WHERE is_ai = TRUE AND is_eliminated = FALSE) as ai_alive, + COUNT(*) FILTER (WHERE is_ai = TRUE AND is_eliminated = TRUE) as ai_eliminated, COUNT(*) FILTER (WHERE is_ai = FALSE AND is_eliminated = FALSE) as human_alive, COUNT(*) FILTER (WHERE is_ai = FALSE AND is_eliminated = TRUE) as human_eliminated FROM drawings WHERE room_id = $1", @@ -463,14 +425,22 @@ async fn check_game_end(socket: &SocketRef, state: &AppState, room: &Room) { return; }; - // ============ 失败条件 1: 杀了太多人类鱼 ============ - const MAX_HUMAN_ELIMINATED: i64 = 3; - if stats.human_eliminated >= MAX_HUMAN_ELIMINATED { + let human_total = stats.human_alive + stats.human_eliminated; + let ai_total = stats.ai_alive + stats.ai_eliminated; + if ai_total < 1 || human_total < 1 { + return; + } + + let max_human_eliminated = + game_rules::human_eliminated_limit(human_total, state.config.human_eliminated_ratio); + + if stats.human_eliminated >= max_human_eliminated { let defeat_data = serde_json::json!({ "reason": "too_many_human_killed", "humanKilled": stats.human_eliminated, "aiRemaining": stats.ai_alive, - "humanRemaining": stats.human_alive + "humanRemaining": stats.human_alive, + "humanTotal": human_total }); let _ = socket .within(room.room_code.clone()) @@ -479,6 +449,15 @@ async fn check_game_end(socket: &SocketRef, state: &AppState, room: &Room) { .bind(room.id) .execute(&state.db) .await; + if let Ok(updated_room) = sqlx::query_as::<_, Room>("SELECT * FROM rooms WHERE id = $1") + .bind(room.id) + .fetch_one(&state.db) + .await + { + if broadcast_phase_update { + emit_phase_update(socket, &updated_room); + } + } tracing::info!( "[Game] Defeat: {} humans killed in room {}", stats.human_eliminated, @@ -487,17 +466,15 @@ async fn check_game_end(socket: &SocketRef, state: &AppState, room: &Room) { return; } - // ============ 原有逻辑 ============ - const VICTORY_MIN_HUMAN: i64 = 5; - const DEFEAT_AI_COUNT: i64 = 5; - - // 胜利: AI 全灭 + 人类 >= 5 - if stats.ai_alive == 0 && stats.human_alive >= VICTORY_MIN_HUMAN { + let min_human_survive = + game_rules::min_human_survive(human_total, state.config.victory_human_survive_ratio); + if stats.ai_alive == 0 && stats.human_alive >= min_human_survive { let victory_data = serde_json::json!({ "mvpId": "", "mvpName": "Unknown", "aiRemaining": stats.ai_alive, - "humanRemaining": stats.human_alive + "humanRemaining": stats.human_alive, + "humanTotal": human_total }); let _ = socket .within(room.room_code.clone()) @@ -506,14 +483,28 @@ async fn check_game_end(socket: &SocketRef, state: &AppState, room: &Room) { .bind(room.id) .execute(&state.db) .await; + if let Ok(updated_room) = sqlx::query_as::<_, Room>("SELECT * FROM rooms WHERE id = $1") + .bind(room.id) + .fetch_one(&state.db) + .await + { + if broadcast_phase_update { + emit_phase_update(socket, &updated_room); + } + } tracing::info!("[Game] Victory in room {}", room.room_code); } - // 失败: AI > 5 - else if stats.ai_alive > DEFEAT_AI_COUNT { + // 失败: AI 过载(相对优势过大) + else if game_rules::ai_overflow( + stats.ai_alive, + stats.human_alive, + state.config.ai_overflow_delta, + ) { let defeat_data = serde_json::json!({ "reason": "ai_overrun", "aiRemaining": stats.ai_alive, - "humanRemaining": stats.human_alive + "humanRemaining": stats.human_alive, + "humanTotal": human_total }); let _ = socket .within(room.room_code.clone()) @@ -522,6 +513,15 @@ async fn check_game_end(socket: &SocketRef, state: &AppState, room: &Room) { .bind(room.id) .execute(&state.db) .await; + if let Ok(updated_room) = sqlx::query_as::<_, Room>("SELECT * FROM rooms WHERE id = $1") + .bind(room.id) + .fetch_one(&state.db) + .await + { + if broadcast_phase_update { + emit_phase_update(socket, &updated_room); + } + } tracing::info!("[Game] Defeat: AI overrun in room {}", room.room_code); } } @@ -541,6 +541,15 @@ async fn on_comment_add(socket: SocketRef, Data(data): Data) { // === Helper Types === +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ConnectAuthData { + #[serde(default)] + token: Option, + #[serde(default)] + authorization: Option, +} + #[derive(Debug, serde::Deserialize)] #[serde(rename_all = "camelCase")] struct RoomJoinData { @@ -557,7 +566,6 @@ struct RoomLeaveData { #[serde(rename_all = "camelCase")] struct BattleVoteCastData { fish_id: String, - voter_id: String, } #[derive(Debug, serde::Deserialize, serde::Serialize)] @@ -574,6 +582,76 @@ struct CommentData { content: String, } +fn extract_auth_token(auth: &ConnectAuthData) -> Option { + if let Some(token) = auth + .token + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()) + { + return Some(token.to_string()); + } + + let authorization = auth + .authorization + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty())?; + + if authorization.len() >= 7 && authorization[..7].eq_ignore_ascii_case("Bearer ") { + let bearer = authorization[7..].trim(); + if bearer.is_empty() { + None + } else { + Some(bearer.to_string()) + } + } else { + Some(authorization.to_string()) + } +} + +fn authenticated_voter_id(socket: &SocketRef) -> Option { + socket + .extensions + .get::() + .map(|session| session.user_id.to_string()) +} + +#[cfg(test)] +mod tests { + use super::{extract_auth_token, ConnectAuthData}; + + #[test] + fn extract_auth_token_prefers_token_field() { + let payload = ConnectAuthData { + token: Some("token-123".to_string()), + authorization: Some("Bearer ignored".to_string()), + }; + + assert_eq!(extract_auth_token(&payload), Some("token-123".to_string())); + } + + #[test] + fn extract_auth_token_supports_bearer_authorization() { + let payload = ConnectAuthData { + token: None, + authorization: Some("Bearer token-xyz".to_string()), + }; + + assert_eq!(extract_auth_token(&payload), Some("token-xyz".to_string())); + } + + #[test] + fn extract_auth_token_returns_none_when_missing() { + let payload = ConnectAuthData { + token: None, + authorization: None, + }; + + assert_eq!(extract_auth_token(&payload), None); + } +} + // === Helper Functions === /// 获取房间初始状态 (sync:state) @@ -635,11 +713,187 @@ async fn get_room_state(state: &AppState, room_code: &str) -> Result, + #[serde(skip_serializing_if = "Option::is_none")] + voting_ends_at: Option, + server_time: i64, +} + +fn emit_phase_update(socket: &SocketRef, room: &Room) { + let payload = PhaseUpdateData { + phase: room.status.clone(), + room_id: room.room_code.clone(), + voting_started_at: room.voting_started_at.map(|t| t.timestamp_millis()), + voting_ends_at: room.voting_ends_at.map(|t| t.timestamp_millis()), + server_time: Utc::now().timestamp_millis(), + }; + let _ = socket + .within(room.room_code.clone()) + .emit("phase:update", &payload); +} + +async fn phase_tick_by_room_code( + socket: &SocketRef, + state: &AppState, + room_code: &str, +) -> Option { + let room: Room = sqlx::query_as("SELECT * FROM rooms WHERE room_code = $1") + .bind(room_code) + .fetch_optional(&state.db) + .await + .ok() + .flatten()?; + phase_tick(socket, state, room).await +} + +async fn phase_tick_by_room_id( + socket: &SocketRef, + state: &AppState, + room_id: uuid::Uuid, +) -> Option { + let room: Room = sqlx::query_as("SELECT * FROM rooms WHERE id = $1") + .bind(room_id) + .fetch_optional(&state.db) + .await + .ok() + .flatten()?; + phase_tick(socket, state, room).await +} + +async fn phase_tick(socket: &SocketRef, state: &AppState, mut room: Room) -> Option { + let before_phase = room.status.clone(); + + if room.status == "voting" { + let expired = room + .voting_ends_at + .map(|t| Utc::now() >= t) + .unwrap_or(false); + if expired { + check_game_end(socket, state, &room, false).await; + let refreshed: Room = sqlx::query_as("SELECT * FROM rooms WHERE id = $1") + .bind(room.id) + .fetch_one(&state.db) + .await + .ok()?; + if refreshed.status != "gameover" { + reset_votes_and_exit_voting(state, refreshed.id).await; + room = sqlx::query_as("SELECT * FROM rooms WHERE id = $1") + .bind(refreshed.id) + .fetch_one(&state.db) + .await + .ok()?; + } else { + room = refreshed; + } + } + } + + if room.status == "active" { + #[derive(sqlx::FromRow)] + struct AliveStats { + ai_alive: i64, + human_alive: i64, + } + + let stats: AliveStats = sqlx::query_as( + "SELECT + COUNT(*) FILTER (WHERE is_ai = TRUE AND is_hidden = FALSE AND is_eliminated = FALSE) as ai_alive, + COUNT(*) FILTER (WHERE is_ai = FALSE AND is_hidden = FALSE AND is_eliminated = FALSE) as human_alive + FROM drawings WHERE room_id = $1", + ) + .bind(room.id) + .fetch_one(&state.db) + .await + .ok()?; + + let min_humans = state.config.min_humans_to_start_voting.max(1); + let submit_deadline = + room.created_at + Duration::seconds(state.config.submit_duration_seconds.max(10)); + let submit_time_up = Utc::now() >= submit_deadline; + let should_start = (stats.ai_alive >= 1 && stats.human_alive >= min_humans) + || (submit_time_up && stats.ai_alive >= 1 && stats.human_alive >= 1); + if should_start { + room = start_voting(state, room.id).await.unwrap_or(room); + } + } + + if before_phase != room.status { + emit_phase_update(socket, &room); + } + + Some(room) +} + +async fn start_voting(state: &AppState, room_id: uuid::Uuid) -> Option { + let seconds = state.config.voting_duration_seconds.max(5); + let updated: Room = sqlx::query_as( + "UPDATE rooms + SET status = 'voting', + voting_started_at = NOW(), + voting_ends_at = NOW() + ($2 || ' seconds')::interval, + updated_at = NOW() + WHERE id = $1 AND status = 'active' AND voting_started_at IS NULL + RETURNING *", + ) + .bind(room_id) + .bind(seconds) + .fetch_optional(&state.db) + .await + .ok() + .flatten()?; + Some(updated) +} + +async fn reset_votes_and_exit_voting(state: &AppState, room_id: uuid::Uuid) { + let mut tx = match state.db.begin().await { + Ok(tx) => tx, + Err(_) => return, + }; + + let _ = sqlx::query( + "DELETE FROM votes v + USING drawings d + WHERE v.drawing_id = d.id AND d.room_id = $1", + ) + .bind(room_id) + .execute(&mut *tx) + .await; + + let _ = + sqlx::query("UPDATE drawings SET vote_count = 0, updated_at = NOW() WHERE room_id = $1") + .bind(room_id) + .execute(&mut *tx) + .await; + + let _ = sqlx::query( + "UPDATE rooms + SET status = 'active', + voting_started_at = NULL, + voting_ends_at = NULL, + updated_at = NOW() + WHERE id = $1 AND status = 'voting'", + ) + .bind(room_id) + .execute(&mut *tx) + .await; + + let _ = tx.commit().await; +} + /// 更新在线人数 async fn update_online_count(state: &AppState, room_code: &str, delta: i32) { if let Ok(Some(room)) = sqlx::query_as::<_, Room>("SELECT * FROM rooms WHERE room_code = $1") @@ -666,6 +920,11 @@ struct SyncStateData { total_items: i32, ai_count: i32, turbidity: f64, + #[serde(skip_serializing_if = "Option::is_none")] + voting_started_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + voting_ends_at: Option, + server_time: i64, theme: ThemeResponse, items: Vec, } diff --git a/frontend/src/app/game/page.tsx b/frontend/src/app/game/page.tsx index f6a567a..32e8bbc 100644 --- a/frontend/src/app/game/page.tsx +++ b/frontend/src/app/game/page.tsx @@ -12,6 +12,7 @@ import SubmitForm from '@/components/ui/SubmitForm' import { ItemDetailModal, VotingTimer } from '@/components/voting/ItemDetailModal' import { createDrawing, + ensureGuestAuth, voteDrawing, getOrCreateSessionId, } from '@/lib/api' @@ -82,6 +83,7 @@ export default function GamePage() { const [showItemModal, setShowItemModal] = useState(false) const [submitting, setSubmitting] = useState(false) const [sessionId, setSessionId] = useState('') + const [authToken, setAuthToken] = useState('') const [isExporting, setIsExporting] = useState(false) // 导出图片 loading 状态 // 使用固定初始值避免 SSR/CSR hydration 不匹配 const [loadingTip, setLoadingTip] = useState(LOADING_TIPS[0]) // 加载提示轮播 @@ -100,11 +102,30 @@ export default function GamePage() { return () => clearInterval(interval) }, [isSynced]) - // 初始化 session ID + // 初始化 session ID + 游客身份 useEffect(() => { - const id = getOrCreateSessionId() - setSessionId(id) - setPlayerId(id) // 设置玩家 ID + let cancelled = false + + const initIdentity = async () => { + const id = getOrCreateSessionId() + if (cancelled) return + setSessionId(id) + + try { + const auth = await ensureGuestAuth(id) + if (cancelled) return + setAuthToken(auth.token) + setPlayerId(auth.userId) + } catch (err) { + console.error('Failed to initialize guest auth:', err) + } + } + + initIdentity() + + return () => { + cancelled = true + } }, [setPlayerId]) // 刷新页面时,如果没有 roomId,重定向到首页 @@ -117,7 +138,8 @@ export default function GamePage() { // 连接 WebSocket const { submitComment, emit, battleVote, retractVote, chaseVote } = useWebSocket({ roomId: roomId || '', - enabled: !!roomId, + authToken, + enabled: !!roomId && !!authToken, }) // 战斗系统 diff --git a/frontend/src/hooks/useBattleSystem.ts b/frontend/src/hooks/useBattleSystem.ts index 6f8523e..9944d20 100644 --- a/frontend/src/hooks/useBattleSystem.ts +++ b/frontend/src/hooks/useBattleSystem.ts @@ -21,7 +21,6 @@ export function useBattleSystem({ emit }: UseBattleSystemOptions = {}) { const { bullet, fishVotes, - playerId, playerFishId, items, aiCount, @@ -93,12 +92,12 @@ export function useBattleSystem({ emit }: UseBattleSystemOptions = {}) { // 发送投票事件到后端 if (emit) { - emit('vote:cast', { fishId, voterId: playerId }) + emit('vote:cast', { fishId }) } return true }, - [canVote, fireBullet, addFloatingDamage, emit, playerId] + [canVote, fireBullet, addFloatingDamage, emit] ) // 换目标 @@ -120,13 +119,13 @@ export function useBattleSystem({ emit }: UseBattleSystemOptions = {}) { // 发送撤票 + 新投票事件 if (emit) { - emit('vote:retract', { fishId: oldTarget, voterId: playerId }) - emit('vote:cast', { fishId: newFishId, voterId: playerId }) + emit('vote:retract', { fishId: oldTarget }) + emit('vote:cast', { fishId: newFishId }) } return true }, - [bullet.currentTarget, changeTarget, addFloatingDamage, emit, playerId] + [bullet.currentTarget, changeTarget, addFloatingDamage, emit] ) // 追击 @@ -141,12 +140,12 @@ export function useBattleSystem({ emit }: UseBattleSystemOptions = {}) { // 发送追击事件 if (emit) { - emit('vote:chase', { fishId, voterId: playerId }) + emit('vote:chase', { fishId }) } return true }, - [canChase, chaseFire, emit, playerId] + [canChase, chaseFire, emit] ) // 获取鱼的当前票数 diff --git a/frontend/src/hooks/useWebSocket.ts b/frontend/src/hooks/useWebSocket.ts index 130e39c..884fb05 100644 --- a/frontend/src/hooks/useWebSocket.ts +++ b/frontend/src/hooks/useWebSocket.ts @@ -18,6 +18,7 @@ import { ENV_CONFIG } from '@/config/env' interface UseWebSocketOptions { url?: string roomId: string + authToken?: string enabled?: boolean } @@ -119,7 +120,7 @@ function convertBackendItem(item: BackendGameItem): GameItem { } } -export function useWebSocket({ url, roomId, enabled = true }: UseWebSocketOptions) { +export function useWebSocket({ url, roomId, authToken, enabled = true }: UseWebSocketOptions) { const socketRef = useRef(null) const { addItem, @@ -147,7 +148,7 @@ export function useWebSocket({ url, roomId, enabled = true }: UseWebSocketOption // 连接 WebSocket useEffect(() => { - if (!enabled || !roomId) return + if (!enabled || !roomId || !authToken) return // 重置同步状态 setIsSynced(false) @@ -160,6 +161,7 @@ export function useWebSocket({ url, roomId, enabled = true }: UseWebSocketOption reconnection: true, reconnectionDelay: 1000, reconnectionAttempts: 5, + auth: { token: authToken }, }) const socket = socketRef.current @@ -330,6 +332,7 @@ export function useWebSocket({ url, roomId, enabled = true }: UseWebSocketOption }, [ url, roomId, + authToken, enabled, addItem, removeItem, @@ -391,24 +394,24 @@ export function useWebSocket({ url, roomId, enabled = true }: UseWebSocketOption // 投票 const battleVote = useCallback( - (fishId: string, voterId: string) => { - emit('vote:cast', { fishId, voterId }) + (fishId: string) => { + emit('vote:cast', { fishId }) }, [emit] ) // 撤票 const retractVote = useCallback( - (fishId: string, voterId: string) => { - emit('vote:retract', { fishId, voterId }) + (fishId: string) => { + emit('vote:retract', { fishId }) }, [emit] ) // 追击 const chaseVote = useCallback( - (fishId: string, voterId: string) => { - emit('vote:chase', { fishId, voterId }) + (fishId: string) => { + emit('vote:chase', { fishId }) }, [emit] ) diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 9cf2a67..1146443 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -6,6 +6,9 @@ import { ENV_CONFIG } from '@/config/env' const API_BASE = ENV_CONFIG.API_URL +const GUEST_DEVICE_TOKEN_STORAGE_KEY = 'mimic_guest_device_token' +const AUTH_TOKEN_STORAGE_KEY = 'mimic_auth_token' +const AUTH_USER_ID_STORAGE_KEY = 'mimic_auth_user_id' // ============ 类型定义 ============ @@ -83,6 +86,14 @@ export interface ReportResponse { hidden: boolean } +export interface GuestLoginResponse { + token: string + userId: string + isNewUser: boolean + isGuest: boolean + deviceToken: string +} + // ============ API 函数 ============ /** @@ -214,6 +225,16 @@ export async function reportDrawing( }) } +export async function guestLogin(payload: { + deviceToken?: string + sessionId?: string +}): Promise { + return request('/api/auth/guest/login', { + method: 'POST', + body: JSON.stringify(payload), + }) +} + // ============ 健康检查 ============ /** @@ -251,6 +272,48 @@ export function getOrCreateSessionId(): string { return sessionId } +export function getGuestDeviceToken(): string | null { + if (typeof window === 'undefined') return null + return localStorage.getItem(GUEST_DEVICE_TOKEN_STORAGE_KEY) +} + +export function setGuestDeviceToken(deviceToken: string) { + if (typeof window === 'undefined') return + localStorage.setItem(GUEST_DEVICE_TOKEN_STORAGE_KEY, deviceToken) +} + +export function getAuthToken(): string | null { + if (typeof window === 'undefined') return null + return localStorage.getItem(AUTH_TOKEN_STORAGE_KEY) +} + +export function getAuthUserId(): string | null { + if (typeof window === 'undefined') return null + return localStorage.getItem(AUTH_USER_ID_STORAGE_KEY) +} + +export function setAuthSession(token: string, userId: string) { + if (typeof window === 'undefined') return + localStorage.setItem(AUTH_TOKEN_STORAGE_KEY, token) + localStorage.setItem(AUTH_USER_ID_STORAGE_KEY, userId) +} + +export async function ensureGuestAuth( + sessionId?: string +): Promise<{ token: string; userId: string }> { + const cachedToken = getAuthToken() + const cachedUserId = getAuthUserId() + if (cachedToken && cachedUserId) { + return { token: cachedToken, userId: cachedUserId } + } + + const deviceToken = getGuestDeviceToken() || undefined + const resp = await guestLogin({ deviceToken, sessionId }) + setGuestDeviceToken(resp.deviceToken) + setAuthSession(resp.token, resp.userId) + return { token: resp.token, userId: resp.userId } +} + /** * 转换后端主题格式为前端格式 */ diff --git a/frontend/src/types/battle.ts b/frontend/src/types/battle.ts index 6467bc8..e6fee23 100644 --- a/frontend/src/types/battle.ts +++ b/frontend/src/types/battle.ts @@ -87,17 +87,14 @@ export type BattleWSEventType = // 投票事件数据 (前端 → 后端) export interface VoteCastData { fishId: string - voterId: string } export interface VoteRetractData { fishId: string - voterId: string } export interface VoteChaseData { fishId: string - voterId: string } // 投票更新事件数据 (后端 → 前端) diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 60084b6..5dfbb57 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -80,7 +80,6 @@ export interface DrawingState { // 投票 export interface VotePayload { itemId: string - voterId: string } // WebSocket 事件 diff --git a/n8n-setup.md b/n8n-setup.md index dde2c50..10dc8a0 100644 --- a/n8n-setup.md +++ b/n8n-setup.md @@ -6,9 +6,10 @@ 原 Zenmux API 已不可用,系统目前使用预置鱼图片。 如需启用 AI 生成,请: -1. 配置新的图像生成 API -2. 更新 n8n workflow 中的 API endpoint +1. 配置可用的图像生成 API(本仓库 n8n workflow 默认使用 Zenmux OpenAI-compatible images endpoint) +2. 在 n8n 中导入并激活工作流 3. 设置环境变量 `AI_GENERATION_ENABLED=true` +4. 配置模型与尺寸(默认优先使用 nano banana pro) ## 导入工作流 @@ -29,12 +30,35 @@ > **注意**: 需要先获取有效的 API Token。请根据选择的 API 提供商获取凭证。 +## 模型与速度/质量权衡(工作流参数) + +该工作流包含 fast → quality → fallback 的分层重试策略,并会**优先选择 nano banana pro**(可通过环境变量覆盖)。 + +建议在 n8n 容器环境变量中配置(或通过 n8n 的环境变量注入方式): + +``` +ZENMUX_MODEL_FAST=nano banana pro +ZENMUX_MODEL_QUALITY=nano banana pro +ZENMUX_MODEL_FALLBACK=google/gemini-3-pro-image-preview + +ZENMUX_SIZE_FAST=256x256 +ZENMUX_SIZE_QUALITY=512x512 +ZENMUX_SIZE_FALLBACK=512x512 +``` + +说明: +- fast:追求更快返回(更小尺寸、较短 timeout),用于“先生成能用的图” +- quality:更高分辨率与更长 timeout,失败时自动降级/回退 +- fallback:当主模型/质量模型失败时启用的兜底模型 + ## 激活工作流 1. 打开导入的工作流 2. 点击右上角 **Activate** 开关 3. Webhook URL 将变为: `http://localhost:5678/webhook/mimic-ai-generate` +> 重要:后端触发 webhook 的 HTTP timeout 为 10 秒。该工作流已调整为“先快速响应,再后台生成并 callback”,避免后端因等待图片生成而超时。 + ## 验证 后端 `.env` 中需要配置: diff --git a/n8n-workflow.json b/n8n-workflow.json index 294e3f0..2d06d42 100644 --- a/n8n-workflow.json +++ b/n8n-workflow.json @@ -18,6 +18,62 @@ ], "webhookId": "mimic-ai-generate" }, + { + "parameters": { + "jsCode": "const input = $json;\n\nconst theme = input.theme || {};\nconst keywords = Array.isArray(theme.keywords) ? theme.keywords : [];\nconst palette = Array.isArray(theme.palette) ? theme.palette : [];\nconst promptStyle = typeof theme.prompt_style === 'string' ? theme.prompt_style : 'stylized illustration';\n\nconst keyword = keywords.length ? keywords[Math.floor(Math.random() * keywords.length)] : 'fish';\nconst colors = palette.slice(0, 3).join(', ');\n\nconst modelFast = process.env.ZENMUX_MODEL_FAST || 'nano banana pro';\nconst modelQuality = process.env.ZENMUX_MODEL_QUALITY || modelFast;\nconst modelFallback = process.env.ZENMUX_MODEL_FALLBACK || 'google/gemini-3-pro-image-preview';\n\nconst sizeFast = process.env.ZENMUX_SIZE_FAST || '256x256';\nconst sizeQuality = process.env.ZENMUX_SIZE_QUALITY || '512x512';\nconst sizeFallback = process.env.ZENMUX_SIZE_FALLBACK || sizeQuality;\n\nconst promptFast = `A ${promptStyle} of a ${keyword}. Simple clean background. High contrast, crisp silhouette. Colors: ${colors}. No text, no watermark, no logo.`;\nconst promptQuality = `A ${promptStyle} of a ${keyword}. Highly detailed but clean silhouette, crisp edges, soft shading. Simple background. Colors: ${colors}. No text, no watermark, no logo.`;\n\nconst name = keyword;\nconst description = `${promptStyle}`.slice(0, 60);\n\nreturn [{\n json: {\n ...input,\n keyword,\n name,\n description,\n model_fast: modelFast,\n model_quality: modelQuality,\n model_fallback: modelFallback,\n size_fast: sizeFast,\n size_quality: sizeQuality,\n size_fallback: sizeFallback,\n prompt_fast: promptFast,\n prompt_quality: promptQuality,\n }\n}];\n" + }, + "id": "prepare", + "name": "Prepare", + "type": "n8n-nodes-base.code", + "typeVersion": 2, + "position": [ + 420, + 300 + ] + }, + { + "parameters": { + "respondWith": "json", + "responseBody": "={{ JSON.stringify({ \"received\": true, \"task_id\": $('Webhook').item.json.task_id }) }}" + }, + "id": "respond-webhook", + "name": "Respond to Webhook", + "type": "n8n-nodes-base.respondToWebhook", + "typeVersion": 1.1, + "position": [ + 640, + 300 + ] + }, + { + "parameters": { + "method": "POST", + "url": "https://api.zenmux.ai/v1/images/generations", + "authentication": "predefinedCredentialType", + "nodeCredentialType": "httpBearerAuth", + "sendBody": true, + "specifyBody": "json", + "jsonBody": "={{ JSON.stringify({\n \"model\": $('Prepare').item.json.model_fast,\n \"prompt\": $('Prepare').item.json.prompt_fast,\n \"n\": 1,\n \"size\": $('Prepare').item.json.size_fast,\n \"response_format\": \"b64_json\"\n}) }}", + "options": { + "timeout": 30000 + } + }, + "id": "generate-fast", + "name": "Generate Fast", + "type": "n8n-nodes-base.httpRequest", + "typeVersion": 4.2, + "position": [ + 640, + 160 + ], + "credentials": { + "httpBearerAuth": { + "id": "zenmux-api", + "name": "Zenmux API" + } + }, + "onError": "continueErrorOutput" + }, { "parameters": { "method": "POST", @@ -26,18 +82,18 @@ "nodeCredentialType": "httpBearerAuth", "sendBody": true, "specifyBody": "json", - "jsonBody": "={{ JSON.stringify({\n \"model\": \"google/gemini-3-pro-image-preview\",\n \"prompt\": `A ${$json.theme.prompt_style} of a ${$json.theme.keywords[Math.floor(Math.random() * $json.theme.keywords.length)]}, colors: ${$json.theme.palette.slice(0, 3).join(', ')}`,\n \"n\": 1,\n \"size\": \"256x256\",\n \"response_format\": \"b64_json\"\n}) }}", + "jsonBody": "={{ JSON.stringify({\n \"model\": $('Prepare').item.json.model_quality,\n \"prompt\": $('Prepare').item.json.prompt_quality,\n \"n\": 1,\n \"size\": $('Prepare').item.json.size_quality,\n \"response_format\": \"b64_json\"\n}) }}", "options": { "timeout": 60000 } }, - "id": "generate-image", - "name": "Generate AI Image", + "id": "generate-quality", + "name": "Generate Quality", "type": "n8n-nodes-base.httpRequest", "typeVersion": 4.2, "position": [ - 440, - 300 + 840, + 160 ], "credentials": { "httpBearerAuth": { @@ -50,51 +106,66 @@ { "parameters": { "method": "POST", - "url": "={{ $('Webhook').item.json.callback_url }}", + "url": "https://api.zenmux.ai/v1/images/generations", + "authentication": "predefinedCredentialType", + "nodeCredentialType": "httpBearerAuth", "sendBody": true, "specifyBody": "json", - "jsonBody": "={{ JSON.stringify({\n \"task_id\": $('Webhook').item.json.task_id,\n \"status\": \"completed\",\n \"image_data\": 'data:image/png;base64,' + $json.data[0].b64_json,\n \"name\": $('Webhook').item.json.theme.keywords[Math.floor(Math.random() * $('Webhook').item.json.theme.keywords.length)]\n}) }}", - "options": {} + "jsonBody": "={{ JSON.stringify({\n \"model\": $('Prepare').item.json.model_fallback,\n \"prompt\": $('Prepare').item.json.prompt_quality,\n \"n\": 1,\n \"size\": $('Prepare').item.json.size_fallback,\n \"response_format\": \"b64_json\"\n}) }}", + "options": { + "timeout": 90000 + } }, - "id": "callback-success", - "name": "Callback Success", + "id": "generate-fallback", + "name": "Generate Fallback", "type": "n8n-nodes-base.httpRequest", "typeVersion": 4.2, "position": [ - 660, - 220 - ] + 1040, + 160 + ], + "credentials": { + "httpBearerAuth": { + "id": "zenmux-api", + "name": "Zenmux API" + } + }, + "onError": "continueErrorOutput" }, { "parameters": { "method": "POST", - "url": "={{ $('Webhook').item.json.callback_url }}", + "url": "={{ $('Prepare').item.json.callback_url }}", "sendBody": true, "specifyBody": "json", - "jsonBody": "={{ JSON.stringify({\n \"task_id\": $('Webhook').item.json.task_id,\n \"status\": \"failed\",\n \"error_message\": $json.error?.message || 'Unknown error'\n}) }}", + "jsonBody": "={{ JSON.stringify({\n \"task_id\": $('Prepare').item.json.task_id,\n \"status\": \"completed\",\n \"image_data\": 'data:image/png;base64,' + $json.data[0].b64_json,\n \"name\": $('Prepare').item.json.name,\n \"description\": $('Prepare').item.json.description\n}) }}", "options": {} }, - "id": "callback-error", - "name": "Callback Error", + "id": "callback-success", + "name": "Callback Success", "type": "n8n-nodes-base.httpRequest", "typeVersion": 4.2, "position": [ - 660, - 400 + 1240, + 120 ] }, { "parameters": { - "respondWith": "json", - "responseBody": "={{ JSON.stringify({ \"received\": true, \"task_id\": $('Webhook').item.json.task_id }) }}" + "method": "POST", + "url": "={{ $('Prepare').item.json.callback_url }}", + "sendBody": true, + "specifyBody": "json", + "jsonBody": "={{ JSON.stringify({\n \"task_id\": $('Prepare').item.json.task_id,\n \"status\": \"failed\",\n \"error_message\": $json.error?.message || 'Unknown error'\n}) }}", + "options": {} }, - "id": "respond-webhook", - "name": "Respond to Webhook", - "type": "n8n-nodes-base.respondToWebhook", - "typeVersion": 1.1, + "id": "callback-error", + "name": "Callback Error", + "type": "n8n-nodes-base.httpRequest", + "typeVersion": 4.2, "position": [ - 880, - 300 + 1240, + 240 ] } ], @@ -103,14 +174,30 @@ "main": [ [ { - "node": "Generate AI Image", + "node": "Prepare", + "type": "main", + "index": 0 + } + ] + ] + }, + "Prepare": { + "main": [ + [ + { + "node": "Respond to Webhook", + "type": "main", + "index": 0 + }, + { + "node": "Generate Fast", "type": "main", "index": 0 } ] ] }, - "Generate AI Image": { + "Generate Fast": { "main": [ [ { @@ -121,29 +208,43 @@ ], [ { - "node": "Callback Error", + "node": "Generate Quality", "type": "main", "index": 0 } ] ] }, - "Callback Success": { + "Generate Quality": { "main": [ [ { - "node": "Respond to Webhook", + "node": "Callback Success", + "type": "main", + "index": 0 + } + ], + [ + { + "node": "Generate Fallback", "type": "main", "index": 0 } ] ] }, - "Callback Error": { + "Generate Fallback": { "main": [ [ { - "node": "Respond to Webhook", + "node": "Callback Success", + "type": "main", + "index": 0 + } + ], + [ + { + "node": "Callback Error", "type": "main", "index": 0 } @@ -158,4 +259,4 @@ "tags": [], "triggerCount": 0, "pinData": {} -} \ No newline at end of file +}