diff --git a/webui/README.md b/webui/README.md index 571d93eb..6f2f8ee8 100644 --- a/webui/README.md +++ b/webui/README.md @@ -7,17 +7,21 @@ Web user interface for Kronos financial prediction model, providing intuitive gr - **Multi-format data support**: Supports CSV, Feather and other financial data formats - **Smart time window**: Fixed 400+120 data point time window slider selection - **Real model prediction**: Integrated real Kronos model, supports multiple model sizes +- **US stock mode**: Enter a Yahoo ticker (for example SMCI/AAPL/NVDA) and forecast future business days directly - **Prediction quality control**: Adjustable temperature, nucleus sampling, sample count and other parameters - **Multi-device support**: Supports CPU, CUDA, MPS and other computing devices - **Comparison analysis**: Detailed comparison between prediction results and actual data - **K-line chart display**: Professional financial K-line chart display +- **Probabilistic forecast chart**: Historical line + mean forecast + min-max range + volume panel ## πŸš€ Quick Start ### Method 1: Start with Python script ```bash cd webui -python run.py +uv venv +uv pip install -r requirements.txt +uv run run.py ``` ### Method 2: Start with Shell script @@ -30,7 +34,9 @@ chmod +x start.sh ### Method 3: Start Flask application directly ```bash cd webui -python app.py +uv venv +uv pip install -r requirements.txt +uv run app.py ``` After successful startup, visit http://localhost:7070 @@ -44,6 +50,18 @@ After successful startup, visit http://localhost:7070 5. **Start prediction**: Click prediction button to generate results 6. **View results**: View prediction results in charts and tables +### US Stock Prediction (Yahoo Finance) + +1. Load Kronos model first +2. Enter ticker symbol in the US stock section (for example `SMCI`) +3. Set prediction days and history period +4. Click **Predict US Stock (Yahoo)** +5. View probabilistic chart with: + - Historical close price + - Mean forecast path + - Forecast uncertainty band (min-max) + - Historical and forecasted volume + ## πŸ”§ Prediction Quality Parameters ### Temperature (T) @@ -112,7 +130,7 @@ The system automatically provides comparison analysis between prediction results ### Common Issues 1. **Port occupied**: Modify port number in app.py -2. **Missing dependencies**: Run `pip install -r requirements.txt` +2. **Missing dependencies**: Run `uv pip install -r requirements.txt` 3. **Model loading failed**: Check network connection and model ID 4. **Data format error**: Ensure data column names and format are correct diff --git a/webui/app.py b/webui/app.py index d240a372..1a71e7d5 100644 --- a/webui/app.py +++ b/webui/app.py @@ -4,22 +4,33 @@ import json import plotly.graph_objects as go import plotly.utils +from plotly.subplots import make_subplots from flask import Flask, render_template, request, jsonify from flask_cors import CORS import sys import warnings import datetime +from datetime import timedelta warnings.filterwarnings('ignore') +try: + import yfinance as yf + YFINANCE_AVAILABLE = True +except ImportError: + YFINANCE_AVAILABLE = False + # Add project root directory to path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) try: from model import Kronos, KronosTokenizer, KronosPredictor MODEL_AVAILABLE = True -except ImportError: + MODEL_IMPORT_ERROR = None +except ImportError as e: MODEL_AVAILABLE = False - print("Warning: Kronos model cannot be imported, will use simulated data for demonstration") + MODEL_IMPORT_ERROR = str(e) + print(f"Warning: Kronos model cannot be imported: {MODEL_IMPORT_ERROR}") + print("Model features are disabled until dependencies are installed.") app = Flask(__name__) CORS(app) @@ -122,6 +133,213 @@ def load_data_file(file_path): except Exception as e: return None, f"Failed to load file: {str(e)}" + +def load_us_daily(symbol, min_rows=520, history_period='5y'): + """Load US stock daily bars from Yahoo Finance.""" + if not YFINANCE_AVAILABLE: + return None, "yfinance is not installed. Install it with: uv pip install yfinance" + + try: + df = yf.download(symbol, period=history_period, interval='1d', auto_adjust=False, progress=False) + except Exception as e: + return None, f"Failed to download {symbol}: {str(e)}" + + if df is None or df.empty: + return None, f"No data returned for {symbol} from Yahoo Finance" + + # yfinance can return MultiIndex columns depending on version/options. + if isinstance(df.columns, pd.MultiIndex): + df.columns = [str(col[0]) for col in df.columns] + + df = df.reset_index().rename( + columns={ + 'Date': 'timestamps', + 'Open': 'open', + 'High': 'high', + 'Low': 'low', + 'Close': 'close', + 'Volume': 'volume', + } + ) + + expected = ['timestamps', 'open', 'high', 'low', 'close', 'volume'] + missing = [col for col in expected if col not in df.columns] + if missing: + return None, f"Yahoo Finance response missing columns: {missing}" + + df['timestamps'] = pd.to_datetime(df['timestamps']).dt.tz_localize(None) + for col in ['open', 'high', 'low', 'close', 'volume']: + df[col] = pd.to_numeric(df[col], errors='coerce') + df['amount'] = df['close'] * df['volume'] + + df = df[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']] + df = df.dropna().sort_values('timestamps').reset_index(drop=True) + + if len(df) < min_rows: + return None, f"Not enough rows for robust context. Need >= {min_rows}, got {len(df)} for {symbol}." + + return df, None + + +def future_business_days(last_day, periods): + start = last_day + timedelta(days=1) + return pd.Series(pd.bdate_range(start=start, periods=periods)) + + +def generate_probabilistic_forecast(x_df, x_timestamp, y_timestamp, pred_len, T, top_p, scenario_count): + """Run multiple stochastic forecasts to obtain mean/min/max paths.""" + forecast_cols = ['open', 'high', 'low', 'close', 'volume', 'amount'] + scenario_frames = [] + + for _ in range(scenario_count): + pred = predictor.predict( + df=x_df, + x_timestamp=x_timestamp, + y_timestamp=y_timestamp, + pred_len=pred_len, + T=T, + top_p=top_p, + sample_count=1, + verbose=False, + ) + scenario_frames.append(pred[forecast_cols].reset_index(drop=True)) + + stacked = np.stack([frame.values for frame in scenario_frames], axis=0) + mean_vals = stacked.mean(axis=0) + min_vals = stacked.min(axis=0) + max_vals = stacked.max(axis=0) + + mean_df = pd.DataFrame(mean_vals, columns=forecast_cols, index=y_timestamp) + min_df = pd.DataFrame(min_vals, columns=forecast_cols, index=y_timestamp) + max_df = pd.DataFrame(max_vals, columns=forecast_cols, index=y_timestamp) + + return mean_df, min_df, max_df, scenario_frames + + +def create_us_stock_prob_chart(hist_df, mean_df, min_df, max_df, symbol): + """Create a two-panel probabilistic chart: price + volume.""" + fig = make_subplots( + rows=2, + cols=1, + shared_xaxes=True, + vertical_spacing=0.06, + row_heights=[0.72, 0.28], + subplot_titles=( + f'{symbol.upper()} Probabilistic Price & Volume Forecast', + 'Volume', + ), + ) + + history_x = hist_df['timestamps'] + forecast_x = mean_df.index + + fig.add_trace( + go.Scatter( + x=history_x, + y=hist_df['close'], + mode='lines', + name='Historical Price', + line=dict(color='royalblue', width=2), + ), + row=1, + col=1, + ) + + # Draw range as an envelope between min and max close forecasts. + fig.add_trace( + go.Scatter( + x=forecast_x, + y=max_df['close'], + mode='lines', + line=dict(color='rgba(255, 159, 64, 0)'), + showlegend=False, + hoverinfo='skip', + ), + row=1, + col=1, + ) + + fig.add_trace( + go.Scatter( + x=forecast_x, + y=min_df['close'], + mode='lines', + fill='tonexty', + fillcolor='rgba(255, 159, 64, 0.25)', + line=dict(color='rgba(255, 159, 64, 0)'), + name='Forecast Range (Min-Max)', + hoverinfo='skip', + ), + row=1, + col=1, + ) + + mean_customdata = np.column_stack([min_df['close'].values, max_df['close'].values]) + fig.add_trace( + go.Scatter( + x=forecast_x, + y=mean_df['close'], + mode='lines', + name='Mean Forecast', + line=dict(color='darkorange', width=2), + customdata=mean_customdata, + hovertemplate=( + '%{x|%b %-d, %Y}
' + 'Mean Forecast: %{y:.4f}
' + 'Min Forecast: %{customdata[0]:.4f}
' + 'Max Forecast: %{customdata[1]:.4f}' + '' + ), + ), + row=1, + col=1, + ) + + fig.add_vline( + x=forecast_x[0], + line_color='red', + line_dash='dash', + line_width=2, + ) + + fig.add_trace( + go.Bar( + x=history_x, + y=hist_df['volume'], + name='Historical Volume', + marker_color='lightskyblue', + opacity=0.8, + ), + row=2, + col=1, + ) + + fig.add_trace( + go.Bar( + x=forecast_x, + y=mean_df['volume'], + name='Mean Forecasted Volume', + marker_color='sandybrown', + opacity=0.9, + ), + row=2, + col=1, + ) + + fig.update_layout( + template='plotly_white', + height=760, + hovermode='x unified', + legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='left', x=0.01), + margin=dict(l=40, r=30, t=90, b=40), + ) + + fig.update_xaxes(rangeslider_visible=False) + fig.update_yaxes(title_text='Price', row=1, col=1) + fig.update_yaxes(title_text='Volume', row=2, col=1) + + return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) + def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params): """Save prediction results to file""" try: @@ -623,6 +841,89 @@ def predict(): except Exception as e: return jsonify({'error': f'Prediction failed: {str(e)}'}), 500 + +@app.route('/api/predict-us-stock', methods=['POST']) +def predict_us_stock(): + """Predict US stock daily bars and return a probabilistic chart.""" + try: + if predictor is None: + return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400 + + data = request.get_json() or {} + symbol = str(data.get('symbol', 'F')).strip().upper() + pred_len = int(data.get('pred_len', 5)) + lookback = int(data.get('lookback', 400)) + history_period = str(data.get('history_period', '5y')).strip() + + temperature = float(data.get('temperature', 1.0)) + top_p = float(data.get('top_p', 0.9)) + sample_count = max(1, int(data.get('sample_count', 3))) + scenario_count = max(3, int(data.get('scenario_count', sample_count))) + + if pred_len < 1: + return jsonify({'error': 'pred_len must be >= 1'}), 400 + if lookback < 60: + return jsonify({'error': 'lookback must be >= 60'}), 400 + + df, error = load_us_daily(symbol, min_rows=max(lookback + 5, 120), history_period=history_period) + if error: + return jsonify({'error': error}), 400 + + hist = df.tail(lookback).copy() + feature_cols = ['open', 'high', 'low', 'close', 'volume', 'amount'] + x_df = hist[feature_cols] + x_timestamp = hist['timestamps'] + y_timestamp = future_business_days(hist['timestamps'].iloc[-1], pred_len) + + mean_df, min_df, max_df, _ = generate_probabilistic_forecast( + x_df=x_df, + x_timestamp=x_timestamp, + y_timestamp=y_timestamp, + pred_len=pred_len, + T=temperature, + top_p=top_p, + scenario_count=scenario_count, + ) + + chart_json = create_us_stock_prob_chart(hist, mean_df, min_df, max_df, symbol) + + prediction_results = [] + for ts in mean_df.index: + prediction_results.append({ + 'timestamp': pd.Timestamp(ts).isoformat(), + 'open': float(mean_df.loc[ts, 'open']), + 'high': float(mean_df.loc[ts, 'high']), + 'low': float(mean_df.loc[ts, 'low']), + 'close': float(mean_df.loc[ts, 'close']), + 'volume': float(mean_df.loc[ts, 'volume']), + 'amount': float(mean_df.loc[ts, 'amount']), + 'close_min': float(min_df.loc[ts, 'close']), + 'close_max': float(max_df.loc[ts, 'close']), + }) + + return jsonify({ + 'success': True, + 'mode': 'us_stock', + 'symbol': symbol, + 'chart': chart_json, + 'prediction_results': prediction_results, + 'last_observed': { + 'timestamp': hist['timestamps'].iloc[-1].isoformat(), + 'open': float(hist['open'].iloc[-1]), + 'high': float(hist['high'].iloc[-1]), + 'low': float(hist['low'].iloc[-1]), + 'close': float(hist['close'].iloc[-1]), + 'volume': float(hist['volume'].iloc[-1]), + 'amount': float(hist['amount'].iloc[-1]), + }, + 'message': ( + f'US stock prediction completed for {symbol}: ' + f'{pred_len} business days, {scenario_count} stochastic scenarios.' + ), + }) + except Exception as e: + return jsonify({'error': f'US stock prediction failed: {str(e)}'}), 500 + @app.route('/api/load-model', methods=['POST']) def load_model(): """Load Kronos model""" @@ -667,7 +968,8 @@ def get_available_models(): """Get available model list""" return jsonify({ 'models': AVAILABLE_MODELS, - 'model_available': MODEL_AVAILABLE + 'model_available': MODEL_AVAILABLE, + 'model_import_error': MODEL_IMPORT_ERROR }) @app.route('/api/model-status') diff --git a/webui/requirements.txt b/webui/requirements.txt index 8a47f81c..69ceb852 100644 --- a/webui/requirements.txt +++ b/webui/requirements.txt @@ -5,3 +5,7 @@ numpy>=1.26.0 plotly==5.17.0 torch>=2.1.0 huggingface_hub==0.33.1 +einops==0.8.1 +tqdm==4.67.1 +safetensors==0.6.2 +yfinance>=0.2.40 diff --git a/webui/templates/index.html b/webui/templates/index.html index dd24a49e..368f4753 100644 --- a/webui/templates/index.html +++ b/webui/templates/index.html @@ -494,6 +494,37 @@

🎯 Control Panel

πŸ“ Load Data +
+ + +
+ + + Fetches daily data from Yahoo Finance +
+ +
+ + + Number of future business days to forecast +
+ +
+ + + Historical range pulled from Yahoo Finance +
+ + +