diff --git a/deps.py b/deps.py index dd44be38..36db7f1a 100644 --- a/deps.py +++ b/deps.py @@ -10,6 +10,7 @@ from services.market_data_service import MarketDataService from services.trading_service import TradingService from services.unified_connector_service import UnifiedConnectorService +from services.backtesting_service import BacktestingService from services.websocket_manager import WebSocketManager from utils.bot_archiver import BotArchiver @@ -69,6 +70,11 @@ def get_executor_ws_manager(request: Request) -> ExecutorWebSocketManager: return request.app.state.executor_ws_manager +def get_backtesting_service(request: Request) -> BacktestingService: + """Get BacktestingService from app state.""" + return request.app.state.backtesting_service + + def get_websocket_manager(request: Request) -> WebSocketManager: """Get WebSocketManager from app state.""" return request.app.state.websocket_manager diff --git a/main.py b/main.py index 2a94c47f..d3c817d4 100644 --- a/main.py +++ b/main.py @@ -64,6 +64,7 @@ def patched_save_to_yml(yml_path, cm): from services.docker_service import DockerService # noqa: E402 from services.executor_service import ExecutorService # noqa: E402 from services.executor_ws_manager import ExecutorWebSocketManager # noqa: E402 +from services.backtesting_service import BacktestingService # noqa: E402 from services.gateway_service import GatewayService # noqa: E402 from services.market_data_service import MarketDataService # noqa: E402 from services.trading_service import TradingService # noqa: E402 @@ -226,6 +227,7 @@ async def lifespan(app: FastAPI): broker_password=settings.broker.password ) + backtesting_service = BacktestingService() docker_service = DockerService() gateway_service = GatewayService() bot_archiver = BotArchiver( @@ -264,6 +266,7 @@ async def lifespan(app: FastAPI): websocket_manager = WebSocketManager(market_data_service) app.state.websocket_manager = websocket_manager + app.state.backtesting_service = backtesting_service app.state.bots_orchestrator = bots_orchestrator app.state.docker_service = docker_service app.state.gateway_service = gateway_service diff --git a/routers/backtesting.py b/routers/backtesting.py index 3d68ee9b..f58f0f73 100644 --- a/routers/backtesting.py +++ b/routers/backtesting.py @@ -1,55 +1,60 @@ -from fastapi import APIRouter -from hummingbot.data_feed.candles_feed.candles_factory import CandlesFactory -from hummingbot.strategy_v2.backtesting.backtesting_engine_base import BacktestingEngineBase +from fastapi import APIRouter, Depends, HTTPException -from config import settings +from deps import get_backtesting_service from models.backtesting import BacktestingConfig +from services.backtesting_service import BacktestingService router = APIRouter(tags=["Backtesting"], prefix="/backtesting") -candles_factory = CandlesFactory() -backtesting_engine = BacktestingEngineBase() - - -@router.post("/run-backtesting") -async def run_backtesting(backtesting_config: BacktestingConfig): - """ - Run a backtesting simulation with the provided configuration. - - Args: - backtesting_config: Configuration for the backtesting including start/end time, - resolution, trade cost, and controller config - - Returns: - Dictionary containing executors, processed data, and results from the backtest - - Raises: - Returns error dictionary if backtesting fails - """ + + +@router.post("/run") +async def run_backtesting( + backtesting_config: BacktestingConfig, + service: BacktestingService = Depends(get_backtesting_service), +): + """Run a backtest synchronously. Returns results directly (may timeout for long backtests).""" try: - if isinstance(backtesting_config.config, str): - controller_config = backtesting_engine.get_controller_config_instance_from_yml( - config_path=backtesting_config.config, - controllers_conf_dir_path=settings.app.controllers_path, - controllers_module=settings.app.controllers_module - ) - else: - controller_config = backtesting_engine.get_controller_config_instance_from_dict( - config_data=backtesting_config.config, - controllers_module=settings.app.controllers_module - ) - backtesting_results = await backtesting_engine.run_backtesting( - controller_config=controller_config, trade_cost=backtesting_config.trade_cost, - start=int(backtesting_config.start_time), end=int(backtesting_config.end_time), - backtesting_resolution=backtesting_config.backtesting_resolution) - processed_data = backtesting_results["processed_data"]["features"].fillna(0) - executors_info = [e.to_dict() for e in backtesting_results["executors"]] - backtesting_results["processed_data"] = processed_data.to_dict() - results = backtesting_results["results"] - results["sharpe_ratio"] = results["sharpe_ratio"] if results["sharpe_ratio"] is not None else 0 - return { - "executors": executors_info, - "processed_data": backtesting_results["processed_data"], - "results": backtesting_results["results"], - } + return await service.run_backtest_sync(backtesting_config.model_dump()) except Exception as e: return {"error": str(e)} + + +@router.post("/tasks") +async def create_backtest_task( + backtesting_config: BacktestingConfig, + service: BacktestingService = Depends(get_backtesting_service), +): + """Submit a backtest as a background task. Returns task ID for polling.""" + task = service.submit_task(backtesting_config.model_dump()) + return {"task_id": task.task_id, "status": task.status.value} + + +@router.get("/tasks") +async def list_backtest_tasks( + service: BacktestingService = Depends(get_backtesting_service), +): + """List all backtest tasks with their status (results excluded for brevity).""" + return service.list_tasks() + + +@router.get("/tasks/{task_id}") +async def get_backtest_task( + task_id: str, + service: BacktestingService = Depends(get_backtesting_service), +): + """Get a backtest task by ID, including results if completed.""" + task = service.get_task(task_id) + if task is None: + raise HTTPException(status_code=404, detail=f"Task {task_id} not found") + return task.to_dict(include_result=True) + + +@router.delete("/tasks/{task_id}") +async def delete_backtest_task( + task_id: str, + service: BacktestingService = Depends(get_backtesting_service), +): + """Cancel a running task or remove a completed one.""" + if not service.cancel_task(task_id): + raise HTTPException(status_code=404, detail=f"Task {task_id} not found") + return {"status": "deleted", "task_id": task_id} diff --git a/services/backtesting_service.py b/services/backtesting_service.py new file mode 100644 index 00000000..24e7b31f --- /dev/null +++ b/services/backtesting_service.py @@ -0,0 +1,177 @@ +""" +BacktestingService manages background backtesting tasks. +Stores task state and results in memory for polling. +""" +import asyncio +import logging +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, Optional + +from hummingbot.strategy_v2.backtesting.backtesting_engine_base import BacktestingEngineBase + +from config import settings + +logger = logging.getLogger(__name__) + + +class BacktestTaskStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class BacktestTask: + def __init__(self, task_id: str, config: dict): + self.task_id = task_id + self.config = config + self.status = BacktestTaskStatus.PENDING + self.created_at = datetime.now(timezone.utc) + self.started_at: Optional[datetime] = None + self.completed_at: Optional[datetime] = None + self.result: Optional[Dict[str, Any]] = None + self.error: Optional[str] = None + self._asyncio_task: Optional[asyncio.Task] = None + + def to_dict(self, include_result: bool = True) -> dict: + data = { + "task_id": self.task_id, + "status": self.status.value, + "config": self.config, + "created_at": self.created_at.isoformat(), + "started_at": self.started_at.isoformat() if self.started_at else None, + "completed_at": self.completed_at.isoformat() if self.completed_at else None, + "error": self.error, + } + if include_result and self.result is not None: + data["result"] = self.result + return data + + +class BacktestingService: + def __init__(self, max_tasks: int = 50): + self._tasks: Dict[str, BacktestTask] = {} + self._engine = BacktestingEngineBase() + self._max_tasks = max_tasks + + @property + def tasks(self) -> Dict[str, BacktestTask]: + return self._tasks + + def submit_task(self, config: dict) -> BacktestTask: + """Submit a new backtesting task to run in the background.""" + self._cleanup_old_tasks() + task_id = str(uuid.uuid4())[:8] + task = BacktestTask(task_id=task_id, config=config) + self._tasks[task_id] = task + task._asyncio_task = asyncio.create_task(self._run_task(task)) + logger.info(f"Backtesting task {task_id} submitted") + return task + + def get_task(self, task_id: str) -> Optional[BacktestTask]: + return self._tasks.get(task_id) + + def cancel_task(self, task_id: str) -> bool: + """Cancel a running task or remove a completed one.""" + task = self._tasks.get(task_id) + if task is None: + return False + if task._asyncio_task and not task._asyncio_task.done(): + task._asyncio_task.cancel() + task.status = BacktestTaskStatus.CANCELLED + task.completed_at = datetime.now(timezone.utc) + del self._tasks[task_id] + return True + + def list_tasks(self) -> list: + """List all tasks (without full results for brevity).""" + return [t.to_dict(include_result=False) for t in self._tasks.values()] + + async def run_backtest_sync(self, config: dict) -> dict: + """Run a backtest synchronously (returns full result directly).""" + return await self._execute_backtest(config) + + async def _run_task(self, task: BacktestTask): + """Background coroutine that executes the backtest.""" + task.status = BacktestTaskStatus.RUNNING + task.started_at = datetime.now(timezone.utc) + try: + task.result = await self._execute_backtest(task.config) + task.status = BacktestTaskStatus.COMPLETED + logger.info(f"Backtesting task {task.task_id} completed") + except asyncio.CancelledError: + task.status = BacktestTaskStatus.CANCELLED + logger.info(f"Backtesting task {task.task_id} cancelled") + except Exception as e: + task.status = BacktestTaskStatus.FAILED + task.error = str(e) + logger.error(f"Backtesting task {task.task_id} failed: {e}") + finally: + task.completed_at = datetime.now(timezone.utc) + + async def _execute_backtest(self, config: dict) -> dict: + """Core backtest execution logic shared by sync and async modes.""" + if isinstance(config["config"], str): + controller_config = self._engine.get_controller_config_instance_from_yml( + config_path=config["config"], + controllers_conf_dir_path=settings.app.controllers_path, + controllers_module=settings.app.controllers_module + ) + else: + controller_config = self._engine.get_controller_config_instance_from_dict( + config_data=config["config"], + controllers_module=settings.app.controllers_module + ) + backtesting_results = await self._engine.run_backtesting( + controller_config=controller_config, + trade_cost=config.get("trade_cost", 0.0006), + start=int(config["start_time"]), + end=int(config["end_time"]), + backtesting_resolution=config.get("backtesting_resolution", "1m"), + ) + processed_data = backtesting_results["processed_data"]["features"].fillna(0) + executors_info = [e.to_dict() for e in backtesting_results["executors"]] + results = backtesting_results["results"] + results["sharpe_ratio"] = results["sharpe_ratio"] if results["sharpe_ratio"] is not None else 0 + + # Serialize position holds + position_holds = [] + for ph in backtesting_results.get("position_holds", []): + position_holds.append({ + "connector_name": ph.connector_name, + "trading_pair": ph.trading_pair, + "buy_amount_base": float(ph.buy_amount_base), + "buy_amount_quote": float(ph.buy_amount_quote), + "sell_amount_base": float(ph.sell_amount_base), + "sell_amount_quote": float(ph.sell_amount_quote), + "net_amount_base": float(ph.net_amount_base), + "cum_fees_quote": float(ph.cum_fees_quote), + "volume_traded_quote": float(ph.volume_traded_quote), + "is_closed": ph.is_closed, + "n_executors": len(ph.source_executor_ids), + }) + + return { + "executors": executors_info, + "processed_data": processed_data.to_dict(), + "results": results, + "position_holds": position_holds, + "position_held_timeseries": backtesting_results.get("position_held_timeseries", []), + "pnl_timeseries": backtesting_results.get("pnl_timeseries", []), + } + + def _cleanup_old_tasks(self): + """Remove oldest completed/failed tasks if we exceed max_tasks.""" + if len(self._tasks) < self._max_tasks: + return + completed = [ + (tid, t) for tid, t in self._tasks.items() + if t.status in (BacktestTaskStatus.COMPLETED, BacktestTaskStatus.FAILED, BacktestTaskStatus.CANCELLED) + ] + completed.sort(key=lambda x: x[1].completed_at or x[1].created_at) + while len(self._tasks) >= self._max_tasks and completed: + tid, _ = completed.pop(0) + del self._tasks[tid]