diff --git a/app/api/websocket_api.py b/app/api/websocket_api.py index 81ec72f9..0ef62be9 100644 --- a/app/api/websocket_api.py +++ b/app/api/websocket_api.py @@ -6,6 +6,7 @@ from flask import Blueprint, request, jsonify from app.services.websocket_push_service import push_service from app.websocket.websocket_events import get_connection_stats +from datetime import datetime import logging logger = logging.getLogger(__name__) @@ -192,7 +193,7 @@ def test_connection(): test_data = { 'type': 'test', 'message': 'WebSocket连接测试', - 'timestamp': push_service.get_push_status()['connection_stats'] + 'timestamp': datetime.now().isoformat() } push_service.trigger_immediate_push('monitor', test_data) diff --git a/app/services/websocket_push_service.py b/app/services/websocket_push_service.py index cad88aa6..bbc97fde 100644 --- a/app/services/websocket_push_service.py +++ b/app/services/websocket_push_service.py @@ -9,6 +9,7 @@ from datetime import datetime, timedelta from typing import Dict, List, Any, Optional import json +import pandas as pd from app.extensions import db from app.models.realtime_indicator import RealtimeIndicator @@ -43,7 +44,6 @@ def __init__(self): self.event_store = ParquetEventStore() self.is_running = False - self.push_thread = None self.push_interval = 30 # 推送间隔(秒) # 推送配置 @@ -65,43 +65,42 @@ def start_push_service(self): if self.is_running: logger.warning("推送服务已在运行") return - + self.is_running = True - self.push_thread = threading.Thread(target=self._push_loop, daemon=True) - self.push_thread.start() + from app.extensions import socketio + socketio.start_background_task(target=self._push_loop) logger.info("WebSocket推送服务已启动") - + def stop_push_service(self): """停止推送服务""" self.is_running = False - if self.push_thread: - self.push_thread.join(timeout=5) logger.info("WebSocket推送服务已停止") def _push_loop(self): """推送循环""" + from app.extensions import socketio as _sio while self.is_running: try: current_time = datetime.now() - + # 检查各类数据是否需要推送 for data_type, config in self.push_config.items(): if not config['enabled']: continue - + last_push = self.last_push_times.get(data_type) - if (not last_push or + if (not last_push or (current_time - last_push).total_seconds() >= config['interval']): - + self._push_data_type(data_type) self.last_push_times[data_type] = current_time - + # 等待下一次检查 - time.sleep(10) # 每10秒检查一次 - + _sio.sleep(10) # 使用 socketio.sleep 保证 eventlet 兼容 + except Exception as e: logger.error(f"推送循环错误: {e}") - time.sleep(30) # 出错后等待30秒再继续 + _sio.sleep(30) def _push_data_type(self, data_type: str): """推送指定类型的数据""" @@ -127,32 +126,40 @@ def _push_data_type(self, data_type: str): def _push_market_data(self): """推送市场数据""" try: - # 获取活跃股票列表 - active_stocks = self.data_manager.get_active_stocks() - - for stock in active_stocks[:20]: # 限制推送数量 - ts_code = stock['ts_code'] - - # 获取最新数据 - latest_data = self.data_manager.get_latest_data(ts_code, '1min', 1) - if latest_data: - market_data = { - 'ts_code': ts_code, - 'datetime': latest_data[0]['datetime'], - 'open': latest_data[0]['open'], - 'high': latest_data[0]['high'], - 'low': latest_data[0]['low'], - 'close': latest_data[0]['close'], - 'volume': latest_data[0]['volume'], - 'amount': latest_data[0]['amount'], - 'change_pct': self._calculate_change_pct(latest_data[0]) - } - - broadcast_market_data(ts_code, market_data) - broadcast_market_data('all', market_data) # 广播到全局房间 - - logger.debug(f"推送市场数据完成,股票数量: {len(active_stocks)}") - + # 获取有分钟数据的股票列表 + active_stocks = self.data_manager.get_available_minute_stocks() + + pushed_count = 0 + for ts_code in active_stocks[:20]: # 限制推送数量 + # 尝试各周期,优先1min,fallback到更粗粒度 + latest_data = pd.DataFrame() + for period in ['1min', '5min', '15min', '30min', '60min']: + latest_data = self.data_manager.get_minute_latest_data(ts_code, period, 2) + if not latest_data.empty: + break + + if latest_data.empty: + continue + + row = latest_data.iloc[0] + market_data = { + 'ts_code': ts_code, + 'datetime': str(row.get('datetime', '')), + 'open': float(row.get('open', 0)), + 'high': float(row.get('high', 0)), + 'low': float(row.get('low', 0)), + 'close': float(row.get('close', 0)), + 'volume': float(row.get('volume', 0)), + 'amount': float(row.get('amount', 0)), + 'change_pct': self._calculate_change_pct(latest_data) + } + + broadcast_market_data(ts_code, market_data) + broadcast_market_data('all', market_data) # 广播到全局房间 + pushed_count += 1 + + logger.info(f"推送市场数据完成,股票数量: {pushed_count}/{len(active_stocks)}") + except Exception as e: logger.error(f"推送市场数据失败: {e}") @@ -240,18 +247,22 @@ def _push_monitor_data(self): """推送监控数据""" try: # 获取监控数据 + anomaly_result = self.monitor_service.detect_anomalies( + change_threshold=5.0, volume_threshold=3.0 + ) + anomaly_list = anomaly_result.get('data', {}).get('anomalies', []) \ + if isinstance(anomaly_result, dict) and anomaly_result.get('success') else [] + monitor_data = { - 'market_overview': self.monitor_service.get_market_overview(), - 'top_movers': self.monitor_service.get_top_movers(limit=10), - 'anomalies': self.monitor_service.detect_anomalies( - change_threshold=5.0, volume_threshold=3.0 - ), - 'sentiment': self.monitor_service.calculate_market_sentiment(period_hours=1) + 'market_overview': self.data_manager.get_market_overview(), + 'top_movers': anomaly_list, + 'anomalies': anomaly_list, + 'sentiment': self.monitor_service.get_market_sentiment(period_hours=1) } - + broadcast_monitor_data(monitor_data) - logger.debug("推送监控数据完成") - + logger.info("推送监控数据完成") + except Exception as e: logger.error(f"推送监控数据失败: {e}") @@ -331,17 +342,22 @@ def _get_news_payload(self) -> List[Dict[str, Any]]: """获取可推送的新闻数据。默认不生成模拟新闻。""" return [] - def _calculate_change_pct(self, current_data: Dict) -> float: - """计算涨跌幅""" + def _calculate_change_pct(self, latest_data) -> float: + """计算涨跌幅,传入最近2条DataFrame记录""" try: - # 获取前一个交易日收盘价(简化处理) - prev_close = current_data.get('open', current_data['close']) - current_close = current_data['close'] - + if latest_data.empty: + return 0.0 + current_close = float(latest_data.iloc[0].get('close', 0)) + if len(latest_data) >= 2: + prev_close = float(latest_data.iloc[1].get('close', 0)) + else: + # 只有1条数据时用开盘价作为近似 + prev_close = float(latest_data.iloc[0].get('open', current_close)) + if prev_close and prev_close != 0: return round(((current_close - prev_close) / prev_close) * 100, 2) return 0.0 - + except Exception: return 0.0 diff --git a/app/templates/realtime_analysis/websocket_management.html b/app/templates/realtime_analysis/websocket_management.html index 147fc2e3..d610bb9e 100644 --- a/app/templates/realtime_analysis/websocket_management.html +++ b/app/templates/realtime_analysis/websocket_management.html @@ -257,7 +257,7 @@
消息日志
} try { - socket = io('http://127.0.0.1:5001', { + socket = io(window.location.origin, { transports: ['websocket', 'polling'] }); @@ -656,7 +656,7 @@
${getTypeName(type)}
${timestamp}
- ${JSON.stringify(data, null, 2).substring(0, 200)}... + ${JSON.stringify(data, null, 2).substring(0, 200)}${JSON.stringify(data).length > 200 ? '...' : ''}
`; diff --git a/tests/services/test_websocket_push_service_lifecycle.py b/tests/services/test_websocket_push_service_lifecycle.py new file mode 100644 index 00000000..3b17d64c --- /dev/null +++ b/tests/services/test_websocket_push_service_lifecycle.py @@ -0,0 +1,145 @@ +"""WebSocket推送服务生命周期与合约测试""" +from unittest.mock import patch + +import pandas as pd + +from app.services.websocket_push_service import WebSocketPushService + + +# ---------- lifecycle ---------- + +def test_stop_push_service_sets_is_running_false(): + """stop 应将 is_running 置 False""" + service = WebSocketPushService() + service.is_running = True + service.stop_push_service() + assert service.is_running is False + + +def test_start_push_service_refuses_double_start(): + """运行中再次 start 应 warning 并跳过""" + service = WebSocketPushService() + service.is_running = True + with patch("app.extensions.socketio") as mock_sio: + service.start_push_service() + mock_sio.start_background_task.assert_not_called() + + +# ---------- _push_market_data ---------- + +def test_push_market_data_uses_available_stocks_and_fallback_period(): + """应使用 get_available_minute_stocks + get_minute_latest_data""" + service = WebSocketPushService() + + fake_df = pd.DataFrame([ + {"datetime": "2026-01-01 10:00", "open": 10.0, "high": 10.5, + "low": 9.8, "close": 10.2, "volume": 1000, "amount": 10000}, + {"datetime": "2026-01-01 09:59", "open": 10.1, "high": 10.3, + "low": 9.9, "close": 10.0, "volume": 900, "amount": 9000}, + ]) + + with patch.object(service.data_manager, "get_available_minute_stocks", + return_value=["000001.SZ"]), \ + patch.object(service.data_manager, "get_minute_latest_data", + return_value=fake_df), \ + patch("app.services.websocket_push_service.broadcast_market_data") as bc: + + service._push_market_data() + + assert bc.call_count == 2 # per-stock + 'all' + # First call: specific stock + assert bc.call_args_list[0].args[0] == "000001.SZ" + # Second call: broadcast to 'all' + assert bc.call_args_list[1].args[0] == "all" + + +def test_push_market_data_skips_empty_data(): + """get_minute_latest_data 返回空 DataFrame 时跳过该股票""" + service = WebSocketPushService() + + with patch.object(service.data_manager, "get_available_minute_stocks", + return_value=["000001.SZ"]), \ + patch.object(service.data_manager, "get_minute_latest_data", + return_value=pd.DataFrame()), \ + patch("app.services.websocket_push_service.broadcast_market_data") as bc: + + service._push_market_data() + + bc.assert_not_called() + + +# ---------- _push_monitor_data ---------- + +def test_push_monitor_data_unpacks_anomaly_list(): + """detect_anomalies 返回的包装 dict 应解包为 anomalies 列表""" + service = WebSocketPushService() + + anomaly_response = { + "success": True, + "data": { + "anomalies": [{"ts_code": "000001.SZ", "anomaly_types": ["急涨"]}], + "total_count": 1, + }, + } + + with patch.object(service.data_manager, "get_market_overview", + return_value={"total_stocks": 100}), \ + patch.object(service.monitor_service, "detect_anomalies", + return_value=anomaly_response), \ + patch.object(service.monitor_service, "get_market_sentiment", + return_value={"sentiment": "neutral"}), \ + patch("app.services.websocket_push_service.broadcast_monitor_data") as bc: + + service._push_monitor_data() + + payload = bc.call_args.args[0] + # top_movers 和 anomalies 都应该是 list,不是包装 dict + assert isinstance(payload["top_movers"], list) + assert isinstance(payload["anomalies"], list) + assert payload["top_movers"][0]["ts_code"] == "000001.SZ" + + +def test_push_monitor_data_handles_detect_failure(): + """detect_anomalies 返回 success=False 时应降级为空列表""" + service = WebSocketPushService() + + with patch.object(service.data_manager, "get_market_overview", + return_value={}), \ + patch.object(service.monitor_service, "detect_anomalies", + return_value={"success": False, "message": "error"}), \ + patch.object(service.monitor_service, "get_market_sentiment", + return_value={}), \ + patch("app.services.websocket_push_service.broadcast_monitor_data") as bc: + + service._push_monitor_data() + + payload = bc.call_args.args[0] + assert payload["top_movers"] == [] + assert payload["anomalies"] == [] + + +# ---------- _calculate_change_pct ---------- + +def test_calculate_change_pct_with_two_rows(): + service = WebSocketPushService() + df = pd.DataFrame([ + {"close": 11.0}, + {"close": 10.0}, + ]) + pct = service._calculate_change_pct(df) + assert pct == 10.0 + + +def test_calculate_change_pct_single_row_fallback(): + service = WebSocketPushService() + df = pd.DataFrame([ + {"close": 11.0, "open": 10.0}, + ]) + pct = service._calculate_change_pct(df) + assert pct == 10.0 + + +def test_calculate_change_pct_empty_df(): + service = WebSocketPushService() + pct = service._calculate_change_pct(pd.DataFrame()) + assert pct == 0.0