diff --git a/app/__init__.py b/app/__init__.py index 96dbc9c94..c05708298 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -37,6 +37,9 @@ def create_app(config_name='default'): from app.api.data_jobs_api import data_jobs_bp from app.routes.ml_factor_routes import ml_factor_routes from app.routes.realtime_analysis_routes import realtime_analysis_routes + from app.routes.heatmap import heatmap_routes + from app.routes.pattern_screen import pattern_screen_bp + from app.api.pattern_screen_api import pattern_screen_api app.register_blueprint(api_bp, url_prefix='/api') app.register_blueprint(ml_factor_bp) app.register_blueprint(text2sql_bp) @@ -50,7 +53,10 @@ def create_app(config_name='default'): app.register_blueprint(data_jobs_bp) app.register_blueprint(ml_factor_routes) app.register_blueprint(realtime_analysis_routes) - + app.register_blueprint(heatmap_routes) + app.register_blueprint(pattern_screen_bp) + app.register_blueprint(pattern_screen_api) + from app.main import main_bp app.register_blueprint(main_bp) diff --git a/app/api/pattern_screen_api.py b/app/api/pattern_screen_api.py new file mode 100644 index 000000000..4de13a87f --- /dev/null +++ b/app/api/pattern_screen_api.py @@ -0,0 +1,47 @@ +"""形态选股 API 端点。""" +from flask import Blueprint, request, jsonify +from loguru import logger + +pattern_screen_api = Blueprint('pattern_screen_api', __name__, url_prefix='/api/pattern-screen') + + +@pattern_screen_api.route('/groups', methods=['GET']) +def get_groups(): + """返回形态分组元数据(含命中数)。""" + try: + from app.services.pattern_screen_service import get_pattern_screen_service + svc = get_pattern_screen_service() + groups = svc.get_groups() + return jsonify({'code': 200, 'message': '成功', 'data': groups}) + except Exception as e: + logger.error(f"获取形态分组失败: {e}") + return jsonify({'code': 500, 'message': f'服务器错误: {str(e)}', 'data': None}), 500 + + +@pattern_screen_api.route('/screen', methods=['POST']) +def screen(): + """执行形态筛选,返回结果表格。""" + try: + data = request.get_json() or {} + patterns = data.get('patterns', []) + sort_by = data.get('sort_by', 'pct_chg') + order = data.get('order', 'desc') + limit = data.get('limit', 50) + offset = data.get('offset', 0) + + from app.services.pattern_screen_service import get_pattern_screen_service + svc = get_pattern_screen_service() + result = svc.screen( + patterns=patterns, + sort_by=sort_by, + order=order, + limit=limit, + offset=offset, + ) + return jsonify({'code': 200, 'message': '成功', 'data': result}) + except ValueError as e: + logger.warning(f"形态筛选参数错误: {e}") + return jsonify({'code': 400, 'message': str(e), 'data': None}), 400 + except Exception as e: + logger.error(f"形态筛选失败: {e}") + return jsonify({'code': 500, 'message': f'服务器错误: {str(e)}', 'data': None}), 500 diff --git a/app/routes/heatmap.py b/app/routes/heatmap.py new file mode 100644 index 000000000..9a089effb --- /dev/null +++ b/app/routes/heatmap.py @@ -0,0 +1,30 @@ +"""板块热力图页面路由。""" +from flask import Blueprint, render_template +from loguru import logger +from app.services.heatmap_service import HeatmapService + +heatmap_routes = Blueprint('heatmap_routes', __name__, url_prefix='/heatmap') + + +@heatmap_routes.route('/') +def index(): + """板块热力图页面。""" + try: + service = HeatmapService() + sectors, stocks = service.get_heatmap_data() + trade_date = sectors[0]['trade_date'] if sectors else '' + return render_template( + 'heatmap.html', + sectors_json=sectors, + stocks_json=stocks, + trade_date=trade_date, + ) + except Exception as e: + logger.error(f"热力图加载失败: {e}") + return render_template( + 'heatmap.html', + sectors_json=[], + stocks_json=[], + trade_date='', + error='数据加载失败,请确认 data/data.parquet 是否存在', + ) diff --git a/app/routes/pattern_screen.py b/app/routes/pattern_screen.py new file mode 100644 index 000000000..5b56eee08 --- /dev/null +++ b/app/routes/pattern_screen.py @@ -0,0 +1,10 @@ +"""形态选股页面路由。""" +from flask import Blueprint, render_template + +pattern_screen_bp = Blueprint('pattern_screen', __name__) + + +@pattern_screen_bp.route('/pattern-screen/') +def index(): + """形态选股页面。""" + return render_template('pattern_screen.html') diff --git a/app/services/heatmap_service.py b/app/services/heatmap_service.py new file mode 100644 index 000000000..889129cfd --- /dev/null +++ b/app/services/heatmap_service.py @@ -0,0 +1,57 @@ +"""板块热力图数据服务 — 读取 data/data.parquet 并按行业聚合。""" +import os +import pandas as pd +import numpy as np +from loguru import logger + + +class HeatmapService: + """板块热力图数据聚合服务。""" + + def __init__(self): + self._data_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + 'data', 'data.parquet' + ) + + def get_heatmap_data(self): + """读取 parquet,返回 (sectors_json, stocks_json)。 + + Returns: + tuple: (list[dict], list[dict]) + - sectors: 板块聚合数据,按 total_mv 降序 + - stocks: 全量个股数据(已过滤停牌) + """ + df = pd.read_parquet(self._data_path) + trade_date = str(df['trade_date'].iloc[0]) + + # 过滤停牌/退市 + df = df[df['close'] > 0].copy() + + # 板块聚合 + sectors = [] + for industry, group in df.groupby('industry'): + total_mv_sum = group['total_mv'].sum() + if total_mv_sum == 0: + continue + avg_pct_chg = np.average(group['pct_chg'], weights=group['total_mv']) + sectors.append({ + 'name': industry, + 'avg_pct_chg': round(avg_pct_chg, 2), + 'total_mv': round(total_mv_sum, 2), + 'stock_count': len(group), + 'up_count': int((group['pct_chg'] > 0).sum()), + 'down_count': int((group['pct_chg'] < 0).sum()), + 'net_mf_amount': round(group['net_mf_amount'].sum(), 2), + 'trade_date': trade_date, + }) + + sectors.sort(key=lambda x: x['total_mv'], reverse=True) + + # 个股明细 + stock_cols = ['ts_code', 'name', 'industry', 'pct_chg', 'close', + 'total_mv', 'net_mf_amount', 'turnover_rate'] + stocks_df = df[stock_cols].sort_values('pct_chg', ascending=False) + stocks = stocks_df.where(stocks_df.notna(), None).to_dict(orient='records') + + return sectors, stocks diff --git a/app/services/pattern_screen_service.py b/app/services/pattern_screen_service.py new file mode 100644 index 000000000..a50e205ac --- /dev/null +++ b/app/services/pattern_screen_service.py @@ -0,0 +1,345 @@ +"""Pattern screen service for stock pattern filtering and metadata. + +Provides pattern group metadata and AND-filtered screening across pattern columns +in the market data parquet file. +""" + +from flask import current_app +import pandas as pd +import numpy as np + + +# Pattern group metadata with full field list +PATTERN_GROUPS = [ + { + 'id': 'single_candle', + 'label': '单根K线', + 'fields': [ + {'key': 'pattern_bull_candle', 'label': '阳线'}, + {'key': 'pattern_bear_candle', 'label': '阴线'}, + {'key': 'pattern_hammer', 'label': '锤头线'}, + {'key': 'pattern_doji', 'label': '十字星'}, + {'key': 'pattern_spinning_top', 'label': '纺锤线'}, + {'key': 'pattern_shooting_star', 'label': '流星线'}, + {'key': 'pattern_long_upper_shadow', 'label': '长上影线'}, + {'key': 'pattern_long_lower_shadow', 'label': '长下影线'}, + {'key': 'pattern_gravestone_doji', 'label': '墓碑十字'}, + {'key': 'pattern_dragonfly_doji', 'label': '蜻蜓十字'}, + {'key': 'pattern_hanging_man', 'label': '上吊线'}, + {'key': 'pattern_inverted_hammer', 'label': '倒锤头'}, + {'key': 'pattern_big_bull', 'label': '大阳线'}, + {'key': 'pattern_big_bear', 'label': '大阴线'}, + {'key': 'pattern_medium_bull', 'label': '中阳线'}, + {'key': 'pattern_medium_bear', 'label': '中阴线'}, + {'key': 'pattern_small_bull', 'label': '小阳线'}, + {'key': 'pattern_small_bear', 'label': '小阴线'}, + {'key': 'pattern_no_body', 'label': '无实体线'}, + {'key': 'pattern_no_upper_bull', 'label': '光头阳线'}, + {'key': 'pattern_no_upper_bear', 'label': '光头阴线'}, + {'key': 'pattern_no_lower_bull', 'label': '光脚阳线'}, + {'key': 'pattern_no_lower_bear', 'label': '光脚阴线'}, + {'key': 'pattern_t_shape', 'label': 'T字线'}, + {'key': 'pattern_inverted_t_shape', 'label': '倒T字线'}, + {'key': 'pattern_low_open_high_close', 'label': '低开高走'}, + {'key': 'pattern_high_open_low_close', 'label': '高开低走'}, + {'key': 'pattern_gap_up', 'label': '跳空高开'}, + {'key': 'pattern_gap_down', 'label': '跳空低开'}, + {'key': 'pattern_close_above_prev_close', 'label': '收盘站上前收'}, + {'key': 'pattern_close_below_prev_close', 'label': '收盘跌破前收'}, + {'key': 'pattern_gap_reclaim_prev_close', 'label': '低开收回前收'}, + {'key': 'pattern_gap_fade_below_prev_close', 'label': '高开回落失守前收'}, + {'key': 'pattern_close_high', 'label': '收盘近最高'}, + {'key': 'pattern_flat_open_high_close', 'label': '平开高走'}, + {'key': 'pattern_flat_open_low_close', 'label': '平开低走'}, + {'key': 'pattern_gap_up_close_bull', 'label': '高开收阳'}, + {'key': 'pattern_gap_down_close_bear', 'label': '低开收阴'}, + {'key': 'pattern_open_near_high_close_high', 'label': '开盘即最高附近收盘'}, + {'key': 'pattern_open_near_low_close_low', 'label': '开盘即最低附近收盘'}, + {'key': 'pattern_flat_open', 'label': '平开'}, + {'key': 'pattern_gap_up_fill', 'label': '高开补缺'}, + {'key': 'pattern_gap_down_fill', 'label': '低开补缺'}, + {'key': 'pattern_pin_bar', 'label': 'Pin Bar'}, + {'key': 'pattern_reversal_prelude', 'label': '反包前兆'}, + {'key': 'pattern_high_resistance', 'label': '高位受阻'}, + {'key': 'pattern_low_stabilization', 'label': '低位止跌'}, + {'key': 'pattern_break_prev_high', 'label': '向上突破前高'}, + {'key': 'pattern_break_prev_low', 'label': '向下跌破前低'}, + {'key': 'pattern_false_break', 'label': '假突破回落'}, + {'key': 'pattern_false_breakdown_recovery', 'label': '假跌破回升'}, + ] + }, + { + 'id': 'double_candle', + 'label': '双根K线', + 'fields': [ + {'key': 'pattern_bullish_engulfing', 'label': '阳包阴'}, + {'key': 'pattern_bearish_engulfing', 'label': '阴包阳'}, + {'key': 'pattern_inside_bar', 'label': '孕线'}, + {'key': 'pattern_dark_cloud', 'label': '乌云盖顶'}, + {'key': 'pattern_piercing', 'label': '刺透形态'}, + {'key': 'pattern_tweezer_top', 'label': '镊子顶'}, + {'key': 'pattern_tweezer_bottom', 'label': '镊子底'}, + {'key': 'pattern_gap_break', 'label': '跳空上攻'}, + {'key': 'pattern_gap_down_break', 'label': '跳空下跌'}, + {'key': 'pattern_gap_up_no_fill', 'label': '跳空不补上行'}, + {'key': 'pattern_gap_down_no_fill', 'label': '跳空不补下行'}, + {'key': 'pattern_reversal_bar', 'label': '反转包线'}, + {'key': 'pattern_flat_top', 'label': '平头顶'}, + {'key': 'pattern_flat_bottom', 'label': '平头底'}, + {'key': 'pattern_island_reversal', 'label': '岛形反转'}, + {'key': 'pattern_t_limit', 'label': 'T字板'}, + {'key': 'pattern_limit_reversal_wrap', 'label': '涨停反包'}, + {'key': 'pattern_vol_up', 'label': '放量上涨'}, + {'key': 'pattern_vol_down', 'label': '缩量下跌'}, + {'key': 'pattern_double_volume_bar', 'label': '倍量柱'}, + {'key': 'pattern_breakout_volume_confirm', 'label': '突破放量确认'}, + ] + }, + { + 'id': 'triple_candle', + 'label': '三根K线', + 'fields': [ + {'key': 'pattern_morning_star', 'label': '早晨之星'}, + {'key': 'pattern_evening_star', 'label': '黄昏之星'}, + {'key': 'pattern_morning_doji_star', 'label': '启明星'}, + {'key': 'pattern_evening_doji_star', 'label': '黄昏十字星'}, + {'key': 'pattern_three_black_crows', 'label': '三只乌鸦'}, + {'key': 'pattern_red_three', 'label': '红三兵'}, + {'key': 'pattern_three_outside_up', 'label': '三外升'}, + {'key': 'pattern_three_outside_down', 'label': '三外降'}, + {'key': 'pattern_rising_three_methods', 'label': '上升三法'}, + {'key': 'pattern_falling_three_methods', 'label': '下降三法'}, + {'key': 'pattern_three_up', 'label': '三连阳'}, + {'key': 'pattern_three_down', 'label': '三连阴'}, + {'key': 'pattern_three_yang_kaitai', 'label': '三阳开泰'}, + {'key': 'pattern_three_yin_breakdown', 'label': '三阴破位'}, + ] + }, + { + 'id': 'trend_structure', + 'label': '趋势结构', + 'fields': [ + {'key': 'pattern_up_trend', 'label': '上升趋势'}, + {'key': 'pattern_down_trend', 'label': '下降趋势'}, + {'key': 'pattern_sideways', 'label': '横盘整理'}, + {'key': 'pattern_golden_cross', 'label': '均线金叉'}, + {'key': 'pattern_duck_head', 'label': '老鸭头'}, + {'key': 'pattern_double_bottom', 'label': '双底'}, + {'key': 'pattern_arc_bottom', 'label': '圆弧底'}, + {'key': 'pattern_ma_bull', 'label': '均线多头排列'}, + {'key': 'pattern_high_tight', 'label': '高位强势整理'}, + {'key': 'pattern_pullback_hold', 'label': '回踩不破'}, + {'key': 'pattern_trend_continue', 'label': '趋势中继'}, + {'key': 'pattern_ma_spread_bull', 'label': '均线发散多头'}, + ] + }, + { + 'id': 'volume_price', + 'label': '量价关系', + 'fields': [ + {'key': 'pattern_box_breakout', 'label': '箱体放量突破'}, + {'key': 'pattern_vol_price_up', 'label': '量价齐升'}, + {'key': 'pattern_platform_break', 'label': '平台突破'}, + {'key': 'pattern_triangle_squeeze', 'label': '三角收敛突破'}, + {'key': 'pattern_limit_turnover_strong', 'label': '涨停换手强'}, + {'key': 'pattern_price_volume_bear_divergence', 'label': '价量顶背离'}, + {'key': 'pattern_price_volume_bull_divergence', 'label': '价量底背离'}, + {'key': 'pattern_price_down_volume_up', 'label': '价跌量增'}, + {'key': 'pattern_volume_staircase', 'label': '量能阶梯'}, + {'key': 'pattern_pullback_volume_shrink', 'label': '回调缩量'}, + {'key': 'pattern_high_turnover', 'label': '高换手'}, + {'key': 'pattern_limit_up_volume_shrink', 'label': '一字板缩量'}, + {'key': 'pattern_false_breakout_volume_weak', 'label': '假突破量弱'}, + {'key': 'pattern_floor_volume_price', 'label': '地量价稳'}, + {'key': 'pattern_blowoff_volume_price', 'label': '天量滞涨'}, + {'key': 'pattern_v_reversal', 'label': 'V型反转'}, + ] + }, + { + 'id': 'compound', + 'label': '复合形态', + 'fields': [ + {'key': 'pattern_first_limit', 'label': '首板'}, + {'key': 'pattern_multi_limit', 'label': '连板'}, + {'key': 'pattern_one_word_limit', 'label': '一字板'}, + {'key': 'pattern_limit_down_to_up', 'label': '地天板'}, + {'key': 'pattern_lotus_breakout', 'label': '莲花突破'}, + {'key': 'pattern_midway_refuel', 'label': '中继加油'}, + {'key': 'pattern_consolidation_platform', 'label': '整理平台'}, + {'key': 'pattern_n_breakout', 'label': 'N字突破'}, + {'key': 'pattern_gap_breakaway', 'label': '跳空突破'}, + {'key': 'pattern_channel_breakout', 'label': '通道突破'}, + {'key': 'pattern_flag_breakout', 'label': '旗形突破'}, + ] + }, + { + 'id': 'momentum', + 'label': '动量突破', + 'fields': [ + {'key': 'break_high_20', 'label': '突破20日新高'}, + {'key': 'break_high_60', 'label': '突破60日新高'}, + {'key': 'break_high_120', 'label': '突破120日新高'}, + {'key': 'break_high_250', 'label': '突破250日新高'}, + {'key': 'consec_up_3', 'label': '连续上涨3日'}, + {'key': 'consec_up_5', 'label': '连续上涨5日'}, + ] + }, +] + +# Allowed sort columns +SORT_WHITELIST = [ + "ts_code", "pct_chg", "close", "amount", "total_mv", + "turnover_rate", "vol_ratio_5", "consec_up_days" +] + +# Columns to return in screen results +DISPLAY_COLUMNS = [ + "ts_code", "name", "industry", "pct_chg", "close", "amount", + "total_mv", "turnover_rate", "vol_ratio_5" +] + + +class PatternScreenService: + """Service for pattern metadata and AND-filtered screening.""" + + def __init__(self): + self._df = None + + def _load_df(self) -> pd.DataFrame: + """Load market data from parquet file.""" + data_dir = current_app.config.get('DATA_DIR', 'data') + path = f"{data_dir}/data.parquet" + return pd.read_parquet(path) + + def _ensure_df(self): + """Lazy-load DataFrame if not cached.""" + if self._df is None: + self._df = self._load_df() + + def get_groups(self) -> list: + """Return pattern groups with hit counts filtered by DataFrame columns. + + Returns: + List of group dicts with id, label, and fields (each field has + key, label, count). Fields whose columns are not present in the + DataFrame are excluded. + """ + self._ensure_df() + df_cols = set(self._df.columns) + + groups = [] + for group in PATTERN_GROUPS: + fields = [] + for field in group['fields']: + if field['key'] in df_cols: + count = int(self._df[field['key']].fillna(0).sum()) + fields.append({ + 'key': field['key'], + 'label': field['label'], + 'count': count + }) + if fields: + groups.append({ + 'id': group['id'], + 'label': group['label'], + 'fields': fields + }) + return groups + + def screen( + self, + patterns: list = None, + sort_by: str = "pct_chg", + order: str = "desc", + limit: int = 20, + offset: int = 0 + ) -> dict: + """Screen stocks by pattern filters with AND logic. + + Args: + patterns: List of pattern column keys to filter (AND logic). + sort_by: Column to sort by (must be in SORT_WHITELIST). + order: 'asc' or 'desc'. + limit: Max rows to return (capped at 500). + offset: Rows to skip for pagination. + + Returns: + Dict with keys: total, offset, limit, trade_date, rows (list of dicts). + + Raises: + ValueError: If sort_by, order, or pattern keys are invalid. + """ + if patterns is None: + patterns = [] + + self._ensure_df() + df = self._df.copy() + + # Validate sort_by + if sort_by not in SORT_WHITELIST: + raise ValueError(f"sort_by must be one of {SORT_WHITELIST}, got: {sort_by}") + + # Validate order + if order not in ('asc', 'desc'): + raise ValueError(f"order must be 'asc' or 'desc', got: {order}") + + # Validate pattern keys + df_cols = set(df.columns) + for p in patterns: + if p not in df_cols: + raise ValueError(f"Pattern '{p}' not found in data") + + # Apply AND filter: all selected patterns must be 1 + if patterns: + mask = df[patterns].fillna(0).eq(1).all(axis=1) + df = df[mask] + + # Get trade_date from first row (if any) + trade_date = None + if not df.empty and 'trade_date' in df.columns: + trade_date = str(df['trade_date'].iloc[0]) + + # Sort + ascending = (order == 'asc') + df = df.sort_values(by=sort_by, ascending=ascending) + + # Total before pagination + total = len(df) + + # Cap limit + limit = min(limit, 500) + + # Select display columns (only those present) + cols_to_show = [c for c in DISPLAY_COLUMNS if c in df.columns] + df = df[cols_to_show] + + # Paginate + df = df.iloc[offset:offset + limit] + + # Convert to dict of dicts, NaN -> None + df = df.replace({np.nan: None}) + rows = df.to_dict(orient='records') + + return { + 'total': total, + 'offset': offset, + 'limit': limit, + 'trade_date': trade_date, + 'rows': rows + } + + def invalidate_cache(self): + """Clear cached DataFrame to force reload on next access.""" + self._df = None + + +# Module-level singleton +_service = None + + +def get_pattern_screen_service() -> PatternScreenService: + """Get or create the singleton PatternScreenService instance.""" + global _service + if _service is None: + _service = PatternScreenService() + return _service diff --git a/app/templates/base.html b/app/templates/base.html index c31193e08..4c8d2cf80 100644 --- a/app/templates/base.html +++ b/app/templates/base.html @@ -331,6 +331,16 @@ 选股筛选 + + +``` + +- [ ] **Step 4: Create the template** + +Create `app/templates/pattern_screen.html` with the following full content: + +```html +{% extends "base.html" %} +{% block title %}形态选股{% endblock %} +{% block extra_css %} + +{% endblock %} + +{% block content %} +
+ +
+ +
+
+ + +
+
+ + +
+
+ 交易日: -- + | + 已选: 0 个形态 + | + 匹配: 0 +
+
+
+ +
加载中...
+
+
+
+
+
+ + +{% endblock %} +``` + +- [ ] **Step 5: Verify page loads** + +Run: `python run.py` and open `http://localhost:5000/pattern-screen/` +Expected: page renders with left panel showing grouped checkboxes and right panel showing stock table + +- [ ] **Step 6: Verify nav link works** + +From any page, click "形态选股" in the nav bar. +Expected: navigates to `/pattern-screen/` + +- [ ] **Step 7: Commit** + +```bash +git add app/routes/pattern_screen.py app/templates/pattern_screen.html app/__init__.py app/templates/base.html +git commit -m "feat: add pattern screening page with left panel + right table layout" +``` + +--- + +### Task 4: Integration verification + +**Files:** None new — verification only. + +- [ ] **Step 1: Run all tests** + +Run: `pytest -v` +Expected: all existing tests + new pattern screen tests pass + +- [ ] **Step 2: Manual smoke test** + +1. `python run.py` +2. Open `/pattern-screen/` +3. Verify: groups load with hit counts +4. Check a pattern (e.g. "均线金叉") → click "筛选" → verify table filters +5. Check a second pattern (e.g. "均线多头排列") → verify AND logic (fewer results) +6. Click a table header → verify sorting changes +7. Click page 2 → verify pagination +8. Click "重置" → verify all checkboxes clear and table shows all stocks +9. Type in search box → verify checkbox labels filter + +- [ ] **Step 3: Commit any fixes** + +```bash +git add -A && git commit -m "fix: pattern screen integration fixes" +``` diff --git a/docs/superpowers/plans/2026-06-07-sector-heatmap.md b/docs/superpowers/plans/2026-06-07-sector-heatmap.md new file mode 100644 index 000000000..f974450c5 --- /dev/null +++ b/docs/superpowers/plans/2026-06-07-sector-heatmap.md @@ -0,0 +1,570 @@ +# 板块热力图 Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add a Treemap heatmap page showing A-share sector performance with in-page drill-down to individual stocks. + +**Architecture:** Flask route → HeatmapService reads `data/data.parquet` → injects JSON into Jinja2 template → ECharts renders treemap + JS handles click-to-expand stock table. No new dependencies. + +**Tech Stack:** Flask, pandas, ECharts 5.4.3, Bootstrap 5.1.3, Jinja2 + +**Spec:** `docs/superpowers/specs/2026-06-07-sector-heatmap-design.md` + +--- + +## File Structure + +| Action | File | Responsibility | +|--------|------|----------------| +| Create | `app/services/heatmap_service.py` | Read parquet, aggregate by industry, return JSON | +| Create | `app/routes/heatmap.py` | Blueprint with `GET /heatmap` route | +| Create | `app/templates/heatmap.html` | ECharts treemap + expandable stock table | +| Modify | `app/__init__.py:38-52` | Register `heatmap_routes` blueprint | +| Modify | `app/templates/base.html:333` | Add nav link before 多因子模型 dropdown | +| Create | `tests/services/test_heatmap_service.py` | Unit tests for HeatmapService | + +--- + +### Task 1: HeatmapService — Failing Tests + +**Files:** +- Create: `tests/services/test_heatmap_service.py` + +- [ ] **Step 1: Write failing tests** + +```python +"""Tests for HeatmapService — sector aggregation logic.""" +import json +import pytest +from unittest.mock import patch, MagicMock +import pandas as pd +import numpy as np + + +@pytest.fixture +def sample_df(): + """Minimal DataFrame matching data/data.parquet schema.""" + return pd.DataFrame({ + 'ts_code': ['000001.SZ', '000002.SZ', '000003.SZ', '000004.SZ', '000005.SZ'], + 'name': ['平安银行', '万科A', '测试银行', '测试地产', '停牌股'], + 'industry': ['银行', '全国地产', '银行', '全国地产', '全国地产'], + 'pct_chg': [1.5, -2.0, 0.5, -1.0, 0.0], + 'close': [11.0, 3.5, 5.0, 8.0, 0.0], + 'total_mv': [21300, 3900, 5000, 2000, 100], + 'net_mf_amount': [14000, -5000, 3000, -2000, 0], + 'turnover_rate': [0.5, 1.7, 1.0, 2.0, 0.0], + 'trade_date': ['20260605'] * 5, + }) + + +@pytest.fixture +def service(): + from app.services.heatmap_service import HeatmapService + return HeatmapService() + + +class TestGetHeatmapData: + """Test HeatmapService.get_heatmap_data().""" + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_returns_two_lists(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + sectors, stocks = service.get_heatmap_data() + assert isinstance(sectors, list) + assert isinstance(stocks, list) + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_filters_suspended_stocks(self, mock_read, service, sample_df): + """Stocks with close == 0 should be excluded.""" + mock_read.return_value = sample_df + sectors, stocks = service.get_heatmap_data() + # 停牌股 (close=0) should not appear in stocks + stock_names = [s['name'] for s in stocks] + assert '停牌股' not in stock_names + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_sector_count(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + sectors, _ = service.get_heatmap_data() + industry_names = [s['name'] for s in sectors] + assert len(industry_names) == 2 # 银行 + 全国地产 + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_sector_weighted_pct_chg(self, mock_read, service, sample_df): + """avg_pct_chg should be market-cap weighted average.""" + mock_read.return_value = sample_df + sectors, _ = service.get_heatmap_data() + bank = next(s for s in sectors if s['name'] == '银行') + # Weighted: (1.5*21300 + 0.5*5000) / (21300+5000) = 34450/26300 ≈ 1.3103 + assert abs(bank['avg_pct_chg'] - 1.3103) < 0.01 + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_sector_stock_count(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + sectors, _ = service.get_heatmap_data() + bank = next(s for s in sectors if s['name'] == '银行') + assert bank['stock_count'] == 2 # 2 银行 after filtering + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_sector_up_down_count(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + sectors, _ = service.get_heatmap_data() + realestate = next(s for s in sectors if s['name'] == '全国地产') + # 万科 -2.0, 测试地产 -1.0 (停牌股 filtered out) + assert realestate['down_count'] == 2 + assert realestate['up_count'] == 0 + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_stocks_have_required_fields(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + _, stocks = service.get_heatmap_data() + required = {'name', 'ts_code', 'pct_chg', 'close', 'total_mv', + 'net_mf_amount', 'turnover_rate', 'industry'} + for s in stocks: + assert required.issubset(s.keys()), f"Missing keys: {required - s.keys()}" + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_trade_date_returned(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + sectors, _ = service.get_heatmap_data() + assert sectors[0]['trade_date'] == '20260605' +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `pytest tests/services/test_heatmap_service.py -v` +Expected: FAIL — `ModuleNotFoundError: No module named 'app.services.heatmap_service'` + +--- + +### Task 2: HeatmapService — Implementation + +**Files:** +- Create: `app/services/heatmap_service.py` + +- [ ] **Step 1: Write HeatmapService** + +```python +"""板块热力图数据服务 — 读取 data/data.parquet 并按行业聚合。""" +import os +import pandas as pd +import numpy as np +from loguru import logger + + +class HeatmapService: + """板块热力图数据聚合服务。""" + + def __init__(self): + self._data_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + 'data', 'data.parquet' + ) + + def get_heatmap_data(self): + """读取 parquet,返回 (sectors_json, stocks_json)。 + + Returns: + tuple: (list[dict], list[dict]) + - sectors: 板块聚合数据,按 total_mv 降序 + - stocks: 全量个股数据(已过滤停牌) + """ + df = pd.read_parquet(self._data_path) + trade_date = str(df['trade_date'].iloc[0]) + + # 过滤停牌/退市 + df = df[df['close'] > 0].copy() + + # 板块聚合 + sectors = [] + for industry, group in df.groupby('industry'): + total_mv_sum = group['total_mv'].sum() + if total_mv_sum == 0: + continue + avg_pct_chg = np.average(group['pct_chg'], weights=group['total_mv']) + sectors.append({ + 'name': industry, + 'avg_pct_chg': round(avg_pct_chg, 2), + 'total_mv': round(total_mv_sum, 2), + 'stock_count': len(group), + 'up_count': int((group['pct_chg'] > 0).sum()), + 'down_count': int((group['pct_chg'] < 0).sum()), + 'net_mf_amount': round(group['net_mf_amount'].sum(), 2), + 'trade_date': trade_date, + }) + + sectors.sort(key=lambda x: x['total_mv'], reverse=True) + + # 个股明细 + stock_cols = ['ts_code', 'name', 'industry', 'pct_chg', 'close', + 'total_mv', 'net_mf_amount', 'turnover_rate'] + stocks_df = df[stock_cols].sort_values('pct_chg', ascending=False) + stocks = stocks_df.where(stocks_df.notna(), None).to_dict(orient='records') + + return sectors, stocks +``` + +- [ ] **Step 2: Run tests to verify they pass** + +Run: `pytest tests/services/test_heatmap_service.py -v` +Expected: All 8 tests PASS + +- [ ] **Step 3: Commit** + +```bash +git add app/services/heatmap_service.py tests/services/test_heatmap_service.py +git commit -m "feat: add HeatmapService with sector aggregation logic" +``` + +--- + +### Task 3: Heatmap Route + +**Files:** +- Create: `app/routes/heatmap.py` + +- [ ] **Step 1: Create route blueprint** + +```python +"""板块热力图页面路由。""" +from flask import Blueprint, render_template +from loguru import logger +from app.services.heatmap_service import HeatmapService + +heatmap_routes = Blueprint('heatmap_routes', __name__, url_prefix='/heatmap') + + +@heatmap_routes.route('/') +def index(): + """板块热力图页面。""" + try: + service = HeatmapService() + sectors, stocks = service.get_heatmap_data() + trade_date = sectors[0]['trade_date'] if sectors else '' + return render_template( + 'heatmap.html', + sectors_json=sectors, + stocks_json=stocks, + trade_date=trade_date, + ) + except Exception as e: + logger.error(f"热力图加载失败: {e}") + return render_template( + 'heatmap.html', + sectors_json=[], + stocks_json=[], + trade_date='', + error='数据加载失败,请确认 data/data.parquet 是否存在', + ) +``` + +- [ ] **Step 2: Commit** + +```bash +git add app/routes/heatmap.py +git commit -m "feat: add heatmap route blueprint" +``` + +--- + +### Task 4: Register Blueprint + Nav Link + +**Files:** +- Modify: `app/__init__.py:38-52` +- Modify: `app/templates/base.html:333` + +- [ ] **Step 1: Register heatmap_routes in app/__init__.py** + +Add import at line 39 (after `realtime_analysis_routes`): +```python +from app.routes.heatmap import heatmap_routes +``` + +Add registration at line 52 (after `app.register_blueprint(realtime_analysis_routes)`): +```python +app.register_blueprint(heatmap_routes) +``` + +- [ ] **Step 2: Add nav link in base.html** + +Insert after line 333 (after the 选股筛选 nav-item, before 多因子模型 dropdown): +```html + +``` + +- [ ] **Step 3: Verify app starts** + +Run: `python -c "from app import create_app; app = create_app('development'); print('OK')" ` +Expected: prints `OK` + +- [ ] **Step 4: Commit** + +```bash +git add app/__init__.py app/templates/base.html +git commit -m "feat: register heatmap blueprint and add nav link" +``` + +--- + +### Task 5: Heatmap Template + +**Files:** +- Create: `app/templates/heatmap.html` + +- [ ] **Step 1: Create the template** + +```html +{% extends "base.html" %} + +{% block title %}板块热力图{% endblock %} + +{% block extra_css %} + +{% endblock %} + +{% block content %} +
+ {% if error is defined and error %} +
{{ error }}
+ {% else %} +
+
+ 板块热力图 + {{ trade_date }} +
+
+ +
+ +
+
+
+
+ {% endif %} +
+{% endblock %} + +{% block extra_js %} + +{% endblock %} +``` + +- [ ] **Step 2: Commit** + +```bash +git add app/templates/heatmap.html +git commit -m "feat: add heatmap template with ECharts treemap and stock table" +``` + +--- + +### Task 6: End-to-End Verification + +- [ ] **Step 1: Run all tests** + +Run: `pytest -v` +Expected: All existing + new tests PASS + +- [ ] **Step 2: Start the app and verify** + +Run: `python run.py` +Open browser → `http://localhost:5000/heatmap` + +Verify: +- [ ] Treemap renders with colored rectangles sized by market cap +- [ ] Hover shows tooltip with sector details +- [ ] Click sector → stock table expands below +- [ ] Click same sector → table collapses +- [ ] Click different sector → table switches +- [ ] Nav bar shows "板块热力图" link and it navigates correctly +- [ ] Red = up, green = down + +- [ ] **Step 3: Final commit (if any fixes needed)** + +```bash +git add -A && git commit -m "fix: heatmap end-to-end adjustments" +``` diff --git a/docs/superpowers/specs/2026-06-07-pattern-screen-design.md b/docs/superpowers/specs/2026-06-07-pattern-screen-design.md new file mode 100644 index 000000000..414f6685e --- /dev/null +++ b/docs/superpowers/specs/2026-06-07-pattern-screen-design.md @@ -0,0 +1,180 @@ +# 形态选股功能设计文档 + +## 概述 + +为 quantitative_analysis 项目添加形态选股(Pattern Screening)功能。基于 `data/data.parquet` 宽表中的 132 个 `pattern_*` / `break_high_*` / `consec_up_*` 二值字段,提供分组筛选界面和 API。 + +参考项目:`/Users/henrylin/vscode_space/stock_screener/backend` + +## 方案选型 + +**Approach A: 独立页面 + 独立服务** + +理由:`data/data.parquet` 是预计算宽表(5524 行 × 132 pattern 列),与现有 `ParquetDataReader` 的分区表体系不同,直接 pandas 读取最简单高效,无需耦合现有选股模块。 + +## 后端架构 + +### 文件结构 + +``` +app/services/pattern_screen_service.py # 服务层 +app/api/pattern_screen_api.py # API 层 +app/routes/pattern_screen.py # 页面路由 +app/templates/pattern_screen.html # 页面模板 +``` + +### 数据层 + +- 通过 `current_app.config['DATA_DIR']` 解析路径,读取 `{DATA_DIR}/data.parquet` 到内存,缓存 DataFrame +- 132 个二值形态字段 + 基础字段(ts_code, name, industry, pct_chg, close, amount, total_mv, turnover_rate, vol_ratio_5 等) +- 提供字段元数据(分组定义、中文标签、当日命中数) +- 提供 `invalidate_cache()` 方法,在宽表重建 data job 完成后调用以刷新缓存 + +### 服务层 — PatternScreenService + +```python +class PatternScreenService: + _df: pd.DataFrame # 缓存宽表 + _field_meta: list[dict] # 分组元数据 + + def get_groups() -> list[dict] + # 返回 [{id, label, fields: [{key, label, count}]}] + + def screen(patterns, sort_by, order, limit, offset) -> dict + # 纯 AND 筛选:所有勾选字段必须 == 1 + # patterns: list[str],每个 key 必须存在于 DataFrame 列中,否则返回 400 + # sort_by: 必须在白名单内: ["pct_chg", "close", "amount", "total_mv", + # "turnover_rate", "vol_ratio_5", "consec_up_days"],默认 "pct_chg" + # order: 仅接受 "asc" 或 "desc",默认 "desc" + # limit: 1-500,默认 50 + # offset: >= 0,默认 0 + # 返回 {total, offset, limit, trade_date, rows: [...]} + + def invalidate_cache() + # 清除缓存的 DataFrame,下次调用时重新读取 parquet +``` + +### API 端点 + +| 端点 | 方法 | 用途 | +|---|---|---| +| `/api/pattern-screen/groups` | GET | 返回分组元数据(含命中数) | +| `/api/pattern-screen/screen` | POST | 执行筛选,返回结果表格 | + +POST `/api/pattern-screen/screen` 请求体: + +```json +{ + "patterns": ["pattern_golden_cross", "pattern_ma_bull"], + "sort_by": "pct_chg", + "order": "desc", + "limit": 50, + "offset": 0 +} +``` + +响应格式(遵循 `analysis_api.py` 的 `{code, message, data}` 约定): + +```json +{ + "code": 200, + "message": "成功", + "data": { + "total": 42, + "offset": 0, + "limit": 50, + "trade_date": "20260605", + "rows": [ + { + "ts_code": "000001.SZ", + "name": "平安银行", + "industry": "银行", + "pct_chg": 2.5, + "close": 12.5, + "amount": 1500000, + "total_mv": 24000000, + "turnover_rate": 1.2, + "vol_ratio_5": 1.8 + } + ] + } +} +``` + +#### 参数校验规则 + +- `patterns`: 可选,默认 `[]`(返回全部股票);若提供,每个 key 必须在 DataFrame 列中,否则 400 +- `sort_by`: 可选,默认 `"pct_chg"`;必须在白名单中,否则 400 +- `order`: 可选,默认 `"desc"`;仅接受 `"asc"` / `"desc"`,否则 400 +- `limit`: 可选,默认 50;范围 1-500,超出截断 +- `offset`: 可选,默认 0;必须 >= 0 + +### 形态分组定义 + +直接从参考项目 `meta.py` 的 `_PATTERN_GROUPS` 提取前 7 组(legacy groups): + +| 分组 ID | 中文名 | 字段数 | +|---|---|---| +| single_candle | 单K形态 | 50 | +| double_candle | 双K形态 | 21 | +| triple_candle | 三K形态 | 14 | +| trend_structure | 趋势结构 | 12 | +| volume_price | 量价形态 | 16 | +| compound | 复合形态 | 11 | +| momentum | 动量因子 | 6 | + +运行时自动过滤掉 DataFrame 中不存在的字段。 + +### Blueprint 注册 + +在 `app/__init__.py` 中注册(Pattern A:url_prefix 在 Blueprint 构造函数中声明): +- API blueprint: `pattern_screen_api = Blueprint('pattern_screen_api', __name__, url_prefix='/api/pattern-screen')`,然后 `app.register_blueprint(pattern_screen_api)` +- 页面 blueprint: `pattern_screen_bp = Blueprint('pattern_screen', __name__)`,路由 `@pattern_screen_bp.route('/pattern-screen/')` + +## 前端设计 + +### 页面布局 + +继承 `base.html`,左右分栏: + +- **左侧面板**(固定宽度 300px): + - 顶部搜索框(按中文名过滤形态) + - 分组手风琴(可折叠),每组显示 checkbox + 中文名 + 命中数 + - 底部"重置"和"筛选"按钮 + +- **右侧内容区**: + - 统计栏(已选形态数 + 匹配结果数) + - 结果表格(代码、名称、行业、涨跌幅、现价、成交额、总市值、换手率、量比) + - 表头可点击排序 + - 底部分页控件 + +### 交互流程 + +1. 页面加载 → GET `/api/pattern-screen/groups` → 渲染分组面板 +2. 勾选形态 → 点击"筛选" → POST `/api/pattern-screen/screen` → 更新表格 +3. 点击表头 → 重新筛选(带 sort_by 参数) +4. 点击页码 → 重新筛选(带 offset 参数) +5. 点击"重置" → 清空勾选,显示全部 + +### 技术栈 + +- Bootstrap 5(与项目一致) +- 原生 JavaScript(无额外框架) +- 项目已有 CSS 主题 (`financial-theme.css`) + +## 筛选逻辑 + +- **纯 AND**:所有勾选的形态字段必须同时为 1 +- 无勾选时(`patterns` 为空或省略)返回全部股票(仅排序和分页) +- 不在分组定义中的 DataFrame 列(如 `consec_up_days`)仍可作为 `sort_by` 使用,但不显示在筛选面板中 + +## 导航集成 + +在 `base.html` 导航栏中添加"形态选股"链接,使用 `url_for('pattern_screen.index')` 生成 URL。 + +## 不做的事情 + +- 不做形态回测(与现有 backtest 功能不同) +- 不做自然语言查询 +- 不做形态组合的 AND/OR 混合逻辑(纯 AND) +- 不修改现有选股模块 diff --git a/docs/superpowers/specs/2026-06-07-sector-heatmap-design.md b/docs/superpowers/specs/2026-06-07-sector-heatmap-design.md new file mode 100644 index 000000000..2076a4616 --- /dev/null +++ b/docs/superpowers/specs/2026-06-07-sector-heatmap-design.md @@ -0,0 +1,129 @@ +# 板块热力图功能设计 + +> Date: 2026-06-07 +> Status: Approved + +## 概述 + +新增 Treemap 热力图页面,展示 A 股 110 个行业板块的涨跌分布。板块面积 = 总市值占比,颜色 = 加权涨跌幅(红涨绿跌)。点击板块后页面内展开个股明细表格。 + +数据源:`data/data.parquet`(每日由数据下载任务更新,~5500 行,276 列)。 + +## 数据层 + +### 数据源 + +`data/data.parquet` — 每日全市场宽表快照,包含 industry、pct_chg、total_mv、net_mf_amount 等字段。 + +### 板块聚合逻辑(后端 Python) + +1. 读取 parquet,过滤掉 `close == 0` 的停牌/退市股 +2. 按 `industry` 分组,计算每个板块: + - `avg_pct_chg`:加权平均涨跌幅(权重 = `total_mv`) + - `total_mv`:板块总市值 + - `stock_count`:个股数量 + - `up_count` / `down_count`:涨/跌家数 + - `net_mf_amount`:板块主力净流入额合计 +3. 按板块总市值降序排列 + +### 个股明细 + +前端 JS 从全量数据中按 `industry` 过滤,展示字段: + +| 字段 | 说明 | +|---|---| +| `name` | 股票简称 | +| `pct_chg` | 涨跌幅 % | +| `close` | 收盘价 | +| `total_mv` | 总市值(亿元) | +| `net_mf_amount` | 主力净流入(万元) | +| `turnover_rate` | 换手率 % | + +按 `pct_chg` 降序排列。 + +### 数据传递 + +后端将两份数据注入 Jinja2 模板:`sectors_json`(板块聚合)和 `stocks_json`(全量个股)。前端 JS 直接使用,无需额外 API 调用。 + +## 前端页面结构 + +### 页面布局(从上到下) + +1. **页面标题栏**:板块热力图 · {trade_date},含排序选项和图例说明 +2. **Treemap 主区域**:ECharts treemap,每个矩形 = 一个行业板块 +3. **个股展开区域**:点击板块后动态展开/收起的表格 + +### 交互行为 + +- **点击板块矩形**:treemap 下方展开/切换个股表格(带折叠动画),再次点击同板块则收起 +- **悬停矩形**:ECharts tooltip 显示板块详情(涨跌家数、净流入、市值排名) +- **表格行点击**:不处理(保持简洁) + +### 配色(A 股惯例) + +- 涨:红色渐变 `#c0392b`(大涨)→ `#f5b7b1`(微涨) +- 跌:绿色渐变 `#27ae60`(大跌)→ `#a9dfbf`(微跌) +- 平盘:`#bdc3c7` 灰色 + +### ECharts Treemap 配置 + +- `visualMap` 连续型,范围取当日实际涨跌幅 min/max +- `leafDepth = 1`(只展示板块层级) +- `roam: false`(禁止缩放平移) + +## 文件结构与集成 + +### 新增文件(3 个) + +| 文件 | 用途 | +|---|---| +| `app/services/heatmap_service.py` | 数据聚合服务:读 parquet → 板块汇总 + 个股列表 | +| `app/templates/heatmap.html` | 页面模板:ECharts treemap + 个股展开表格 | +| `app/routes/heatmap.py` | 路由:`GET /heatmap` | + +### 修改文件(2 个) + +| 文件 | 改动 | +|---|---| +| `app/__init__.py` | 注册新 blueprint `heatmap_bp` | +| `app/templates/base.html` | 导航栏追加"板块热力图"链接 | + +### HeatmapService 接口 + +```python +class HeatmapService: + def get_heatmap_data(self) -> tuple[list[dict], list[dict]]: + """返回 (sectors_json, stocks_json)""" +``` + +### Route + +```python +heatmap_bp = Blueprint('heatmap', __name__) + +@heatmap_bp.route('/heatmap') +def heatmap_page(): + sectors, stocks = HeatmapService().get_heatmap_data() + return render_template('heatmap.html', + sectors_json=sectors, + stocks_json=stocks, + trade_date=sectors[0]['trade_date']) +``` + +### 前端 JS(内联,~150 行) + +- `initTreemap(sectors)` — 初始化 ECharts treemap +- `onSectorClick(industry)` — 过滤 stocks 数据,渲染/切换下方表格 +- 表格用原生 HTML `` + Bootstrap 样式 + +### 依赖 + +无新依赖。使用项目已有的 pandas、ECharts 5.4.3、Bootstrap 5.1.3、Jinja2。 + +## 验收标准 + +1. 访问 `/heatmap` 能看到板块 Treemap 热力图,颜色和面积正确 +2. 点击任意板块,下方展开该板块个股表格;再次点击收起 +3. 点击不同板块,表格切换为对应板块的个股 +4. 页面顶部导航有"板块热力图"入口 +5. `data/data.parquet` 更新后刷新页面即可看到新数据 diff --git a/run.py b/run.py index 0fbcd2c7d..c5a7c0d23 100644 --- a/run.py +++ b/run.py @@ -44,4 +44,4 @@ def inspect_runtime_health(flask_app): debug=False, use_reloader=False, allow_unsafe_werkzeug=True - ) + ) diff --git a/tests/api/test_pattern_screen_api.py b/tests/api/test_pattern_screen_api.py new file mode 100644 index 000000000..a4f76a0f2 --- /dev/null +++ b/tests/api/test_pattern_screen_api.py @@ -0,0 +1,98 @@ +"""Pattern screen API contract tests.""" +import pytest +from unittest.mock import patch, MagicMock + + +@pytest.fixture +def client(app): + return app.test_client() + + +@pytest.fixture +def mock_service(): + """Mock the singleton getter for API tests.""" + svc = MagicMock() + svc.get_groups.return_value = [ + { + "id": "trend_structure", + "label": "趋势结构", + "fields": [ + {"key": "pattern_golden_cross", "label": "均线金叉", "count": 304}, + {"key": "pattern_ma_bull", "label": "均线多头排列", "count": 259}, + ], + } + ] + svc.screen.return_value = { + "total": 2, + "offset": 0, + "limit": 50, + "trade_date": "20260605", + "rows": [ + {"ts_code": "000001.SZ", "name": "平安银行", "industry": "银行", + "pct_chg": 1.5, "close": 12.5, "amount": 1000000, + "total_mv": 2400000, "turnover_rate": 1.2, "vol_ratio_5": 1.8}, + {"ts_code": "000003.SZ", "name": "国农科技", "industry": "综合", + "pct_chg": 3.0, "close": 25.0, "amount": 200000, + "total_mv": 600000, "turnover_rate": 2.5, "vol_ratio_5": 2.1}, + ], + } + + patcher = patch( + 'app.services.pattern_screen_service.get_pattern_screen_service', + return_value=svc, + ) + patcher.start() + yield svc + patcher.stop() + + +class TestGetGroups: + def test_returns_200(self, client, mock_service): + resp = client.get('/api/pattern-screen/groups') + assert resp.status_code == 200 + data = resp.get_json() + assert data['code'] == 200 + assert isinstance(data['data'], list) + + def test_group_structure(self, client, mock_service): + resp = client.get('/api/pattern-screen/groups') + data = resp.get_json()['data'] + g = data[0] + assert 'id' in g + assert 'label' in g + assert 'fields' in g + + +class TestScreen: + def test_returns_200(self, client, mock_service): + resp = client.post('/api/pattern-screen/screen', + json={'patterns': ['pattern_golden_cross']}) + assert resp.status_code == 200 + data = resp.get_json() + assert data['code'] == 200 + assert data['data']['total'] == 2 + + def test_empty_patterns(self, client, mock_service): + resp = client.post('/api/pattern-screen/screen', json={}) + assert resp.status_code == 200 + + def test_invalid_sort_by_returns_400(self, client, mock_service): + mock_service.screen.side_effect = ValueError("sort_by 'bad' not in whitelist") + resp = client.post('/api/pattern-screen/screen', + json={'sort_by': 'bad'}) + assert resp.status_code == 400 + assert resp.get_json()['code'] == 400 + + def test_response_format(self, client, mock_service): + resp = client.post('/api/pattern-screen/screen', + json={'patterns': ['pattern_golden_cross']}) + data = resp.get_json()['data'] + assert 'total' in data + assert 'offset' in data + assert 'limit' in data + assert 'trade_date' in data + assert 'rows' in data + row = data['rows'][0] + for col in ['ts_code', 'name', 'industry', 'pct_chg', 'close', + 'amount', 'total_mv', 'turnover_rate', 'vol_ratio_5']: + assert col in row diff --git a/tests/services/test_heatmap_service.py b/tests/services/test_heatmap_service.py new file mode 100644 index 000000000..f0b37c6bd --- /dev/null +++ b/tests/services/test_heatmap_service.py @@ -0,0 +1,93 @@ +"""Tests for HeatmapService — sector aggregation logic.""" +import json +import pytest +from unittest.mock import patch, MagicMock +import pandas as pd +import numpy as np + + +@pytest.fixture +def sample_df(): + """Minimal DataFrame matching data/data.parquet schema.""" + return pd.DataFrame({ + 'ts_code': ['000001.SZ', '000002.SZ', '000003.SZ', '000004.SZ', '000005.SZ'], + 'name': ['平安银行', '万科A', '测试银行', '测试地产', '停牌股'], + 'industry': ['银行', '全国地产', '银行', '全国地产', '全国地产'], + 'pct_chg': [1.5, -2.0, 0.5, -1.0, 0.0], + 'close': [11.0, 3.5, 5.0, 8.0, 0.0], + 'total_mv': [21300, 3900, 5000, 2000, 100], + 'net_mf_amount': [14000, -5000, 3000, -2000, 0], + 'turnover_rate': [0.5, 1.7, 1.0, 2.0, 0.0], + 'trade_date': ['20260605'] * 5, + }) + + +@pytest.fixture +def service(): + from app.services.heatmap_service import HeatmapService + return HeatmapService() + + +class TestGetHeatmapData: + """Test HeatmapService.get_heatmap_data().""" + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_returns_two_lists(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + sectors, stocks = service.get_heatmap_data() + assert isinstance(sectors, list) + assert isinstance(stocks, list) + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_filters_suspended_stocks(self, mock_read, service, sample_df): + """Stocks with close == 0 should be excluded.""" + mock_read.return_value = sample_df + sectors, stocks = service.get_heatmap_data() + stock_names = [s['name'] for s in stocks] + assert '停牌股' not in stock_names + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_sector_count(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + sectors, _ = service.get_heatmap_data() + industry_names = [s['name'] for s in sectors] + assert len(industry_names) == 2 # 银行 + 全国地产 + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_sector_weighted_pct_chg(self, mock_read, service, sample_df): + """avg_pct_chg should be market-cap weighted average.""" + mock_read.return_value = sample_df + sectors, _ = service.get_heatmap_data() + bank = next(s for s in sectors if s['name'] == '银行') + # Weighted: (1.5*21300 + 0.5*5000) / (21300+5000) = 34450/26300 ≈ 1.3103 + assert abs(bank['avg_pct_chg'] - 1.3103) < 0.01 + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_sector_stock_count(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + sectors, _ = service.get_heatmap_data() + bank = next(s for s in sectors if s['name'] == '银行') + assert bank['stock_count'] == 2 # 2 银行 after filtering + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_sector_up_down_count(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + sectors, _ = service.get_heatmap_data() + realestate = next(s for s in sectors if s['name'] == '全国地产') + assert realestate['down_count'] == 2 + assert realestate['up_count'] == 0 + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_stocks_have_required_fields(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + _, stocks = service.get_heatmap_data() + required = {'name', 'ts_code', 'pct_chg', 'close', 'total_mv', + 'net_mf_amount', 'turnover_rate', 'industry'} + for s in stocks: + assert required.issubset(s.keys()), f"Missing keys: {required - s.keys()}" + + @patch('app.services.heatmap_service.pd.read_parquet') + def test_trade_date_returned(self, mock_read, service, sample_df): + mock_read.return_value = sample_df + sectors, _ = service.get_heatmap_data() + assert sectors[0]['trade_date'] == '20260605' diff --git a/tests/services/test_pattern_screen_service.py b/tests/services/test_pattern_screen_service.py new file mode 100644 index 000000000..5d284b1e2 --- /dev/null +++ b/tests/services/test_pattern_screen_service.py @@ -0,0 +1,140 @@ +"""PatternScreenService unit tests.""" +import pytest +from unittest.mock import patch, MagicMock +import pandas as pd +import numpy as np + + +@pytest.fixture +def sample_df(): + """Minimal DataFrame mimicking data/data.parquet structure.""" + return pd.DataFrame({ + 'ts_code': ['000001.SZ', '000002.SZ', '000003.SZ', '000004.SZ'], + 'name': ['平安银行', '万科A', '国农科技', '国信证券'], + 'industry': ['银行', '房地产', '综合', '证券'], + 'trade_date': ['20260605'] * 4, + 'pct_chg': [1.5, -0.5, 3.0, 0.0], + 'close': [12.5, 8.0, 25.0, 15.0], + 'amount': [1000000, 500000, 200000, 300000], + 'total_mv': [2400000, 1200000, 600000, 900000], + 'turnover_rate': [1.2, 0.8, 2.5, 0.5], + 'vol_ratio_5': [1.8, 0.6, 2.1, 0.9], + 'consec_up_days': [2, 0, 3, 1], + 'pattern_golden_cross': [1, 0, 1, 0], + 'pattern_ma_bull': [0, 0, 1, 1], + 'pattern_bull_candle': [1, 0, 1, 0], + 'pattern_bear_candle': [0, 1, 0, 1], + 'break_high_20': [1, 0, 1, 0], + }) + + +@pytest.fixture +def service(sample_df): + """Service with injected test DataFrame.""" + with patch('app.services.pattern_screen_service.PatternScreenService._load_df', return_value=sample_df): + from app.services.pattern_screen_service import PatternScreenService + svc = PatternScreenService() + svc._df = sample_df + return svc + + +class TestGetGroups: + def test_returns_groups_with_hits(self, service): + groups = service.get_groups() + assert isinstance(groups, list) + assert len(groups) > 0 + g = groups[0] + assert 'id' in g + assert 'label' in g + assert 'fields' in g + for f in g['fields']: + assert 'key' in f + assert 'label' in f + assert 'count' in f + assert isinstance(f['count'], int) + + def test_fields_not_in_dataframe_are_excluded(self, service): + groups = service.get_groups() + df_cols = set(service._df.columns) + for g in groups: + for f in g['fields']: + assert f['key'] in df_cols + + +class TestScreen: + def test_no_patterns_returns_all(self, service, sample_df): + result = service.screen(patterns=[]) + assert result['total'] == len(sample_df) + assert len(result['rows']) == len(sample_df) + + def test_single_pattern_filters(self, service): + result = service.screen(patterns=['pattern_golden_cross']) + assert result['total'] == 2 + codes = [r['ts_code'] for r in result['rows']] + assert '000001.SZ' in codes + assert '000003.SZ' in codes + + def test_multiple_patterns_and_logic(self, service): + result = service.screen(patterns=['pattern_golden_cross', 'pattern_ma_bull']) + assert result['total'] == 1 + assert result['rows'][0]['ts_code'] == '000003.SZ' + + def test_sort_desc(self, service): + result = service.screen(patterns=[], sort_by='pct_chg', order='desc') + pcts = [r['pct_chg'] for r in result['rows']] + assert pcts == sorted(pcts, reverse=True) + + def test_sort_asc(self, service): + result = service.screen(patterns=[], sort_by='pct_chg', order='asc') + pcts = [r['pct_chg'] for r in result['rows']] + assert pcts == sorted(pcts) + + def test_pagination(self, service): + result = service.screen(patterns=[], limit=2, offset=0) + assert len(result['rows']) == 2 + assert result['limit'] == 2 + assert result['offset'] == 0 + assert result['total'] == 4 + + def test_offset_beyond_results(self, service): + result = service.screen(patterns=[], offset=100) + assert result['total'] == 4 + assert len(result['rows']) == 0 + + def test_invalid_sort_by_raises(self, service): + with pytest.raises(ValueError, match='sort_by'): + service.screen(patterns=[], sort_by='invalid_column') + + def test_invalid_order_raises(self, service): + with pytest.raises(ValueError, match='order'): + service.screen(patterns=[], order='random') + + def test_invalid_pattern_key_raises(self, service): + with pytest.raises(ValueError, match='pattern.*not found'): + service.screen(patterns=['nonexistent_pattern']) + + def test_limit_capped_at_500(self, service): + result = service.screen(patterns=[], limit=9999) + assert result['limit'] == 500 + + def test_includes_trade_date(self, service): + result = service.screen(patterns=[]) + assert result['trade_date'] == '20260605' + + def test_nan_converted_to_none(self, sample_df): + sample_df.loc[0, 'industry'] = np.nan + with patch('app.services.pattern_screen_service.PatternScreenService._load_df', return_value=sample_df): + from app.services.pattern_screen_service import PatternScreenService + svc = PatternScreenService() + svc._df = sample_df + result = svc.screen(patterns=[]) + # Find the row with ts_code='000001.SZ' (index 0 in original df) + row = next(r for r in result['rows'] if r['ts_code'] == '000001.SZ') + assert row['industry'] is None + + +class TestInvalidateCache: + def test_clears_df(self, service): + assert service._df is not None + service.invalidate_cache() + assert service._df is None