88import secrets
99from uuid import UUID
1010
11- import httpx
1211import pandas as pd
1312from futuresearch .api_utils import handle_response
1413from futuresearch .generated .api .tasks import get_task_status_tasks_task_id_status_get
1817
1918from futuresearch_mcp import redis_store
2019from futuresearch_mcp .config import settings
21- from futuresearch_mcp .tool_helpers import _UI_EXCLUDE , TaskState
20+ from futuresearch_mcp .result_store import _sanitize_records
21+ from futuresearch_mcp .tool_helpers import _UI_EXCLUDE , TaskState , _fetch_task_result
2222
2323logger = logging .getLogger (__name__ )
2424
@@ -81,6 +81,61 @@ async def _validate_poll_token(task_id: str, request: Request) -> JSONResponse |
8181 return None
8282
8383
async def _fetch_summaries_rest(
    client: AuthenticatedClient, task_id: str, cursor: str | None
) -> tuple[list[dict] | None, str | None]:
    """Fetch agent summaries from the Engine API for the REST progress endpoint.

    Best-effort: any failure (network error, non-200 status, malformed JSON)
    is swallowed so progress polling keeps working without summaries.

    Args:
        client: Authenticated Engine API client.
        task_id: Task whose summaries to fetch.
        cursor: Opaque pagination cursor from a previous poll, if any.

    Returns:
        ``(summaries, cursor)`` — ``summaries`` is ``None`` when nothing was
        fetched; ``cursor`` is the server-updated cursor, falling back to the
        input cursor so the caller never loses its pagination position.
    """
    try:
        params: dict[str, str] = {}
        if cursor:
            params["cursor"] = cursor
        httpx_client = client.get_async_httpx_client()
        resp = await httpx_client.request(
            method="get",
            url=f"/tasks/{task_id}/summaries",
            params=params,
        )
        if resp.status_code == 200:
            data = resp.json()
            return data.get("summaries") or None, data.get("cursor") or cursor
        # Non-200: treat as "no new summaries", but leave a trace for debugging
        # instead of dropping the status silently.
        logger.debug(
            "Summaries fetch for task %s returned HTTP %s",
            task_id,
            resp.status_code,
        )
    except Exception:
        # Deliberate best-effort swallow, but keep the traceback at debug
        # level so failures remain diagnosable.
        logger.debug(
            "Failed to fetch summaries for task %s via REST",
            task_id,
            exc_info=True,
        )
    return None, cursor
105+
async def _fetch_aggregate_rest(
    client: AuthenticatedClient, task_id: str, cursor: str | None
) -> tuple[str | None, list[dict] | None, str | None]:
    """Fetch aggregate + micro-summaries from the Engine API.

    Returns ``(aggregate_text, micro_summaries, updated_cursor)``.
    Falls back to plain summaries when the aggregate endpoint is unavailable
    (non-200 response or any exception).

    Args:
        client: Authenticated Engine API client.
        task_id: Task whose aggregate summary to fetch.
        cursor: Opaque pagination cursor from a previous poll, if any.
    """
    try:
        params: dict[str, str] = {}
        if cursor:
            params["cursor"] = cursor
        httpx_client = client.get_async_httpx_client()
        resp = await httpx_client.request(
            method="get",
            url=f"/tasks/{task_id}/summaries/aggregate",
            params=params,
        )
        if resp.status_code == 200:
            data = resp.json()
            return (
                data.get("aggregate") or None,
                data.get("micro_summaries") or None,
                data.get("cursor") or cursor,
            )
    except Exception:
        # Best-effort by design, but don't swallow silently: record the
        # failure at debug level (with traceback) before falling back.
        logger.debug(
            "Aggregate fetch failed for task %s via REST",
            task_id,
            exc_info=True,
        )

    # Fallback: plain summaries without aggregate
    summaries, new_cursor = await _fetch_summaries_rest(client, task_id, cursor)
    return None, summaries, new_cursor
138+
84139async def api_progress (request : Request ) -> Response :
85140 """REST endpoint for the session widget to poll task progress."""
86141 cors = _cors_headers ()
@@ -119,12 +174,28 @@ async def api_progress(request: Request) -> Response:
119174
120175 ts = TaskState (status_response )
121176
122- # Don't pop the token on completion — the download route needs it.
123- # Let the Redis TTL expire it naturally.
177+ if ts .is_terminal :
178+ # Don't pop the token immediately — the widget's autoFetchResults
179+ # needs it to call /download-token after task completion.
180+ # The token will expire naturally via Redis TTL.
181+ pass
124182
125- return JSONResponse (
126- ts .model_dump (mode = "json" , exclude = _UI_EXCLUDE ), headers = cors
127- )
183+ payload = ts .model_dump (mode = "json" , exclude = _UI_EXCLUDE )
184+
185+ # Fetch aggregate + micro-summaries + partial rows for non-terminal tasks
186+ if not ts .is_terminal :
187+ cursor = request .query_params .get ("cursor" )
188+ aggregate , summaries , new_cursor = await _fetch_aggregate_rest (
189+ client , task_id , cursor
190+ )
191+ if aggregate :
192+ payload ["aggregate_summary" ] = aggregate
193+ if summaries :
194+ payload ["summaries" ] = summaries
195+ if new_cursor :
196+ payload ["cursor" ] = new_cursor
197+
198+ return JSONResponse (payload , headers = cors )
128199 except Exception as exc :
129200 logger .error (
130201 "Progress poll failed for task %s: %s" , task_id , type (exc ).__name__
@@ -153,12 +224,11 @@ async def _validate_poll_token_bearer_only(
153224 return None
154225
155226
156- async def api_download_url (request : Request ) -> Response :
157- """Return the download URL for a task .
227+ async def api_download (request : Request ) -> Response : # noqa: PLR0911
228+ """REST endpoint to download task results as CSV or JSON .
158229
159- The widget calls this to get the download URL. Validates the poll
160- token so only the session owner gets the URL (the download itself
161- is open by task ID).
230+ Authenticates via the poll token (Authorization: Bearer header or
231+ ?token= query param). No separate download token needed.
162232 """
163233 cors = _cors_headers ()
164234 if request .method == "OPTIONS" :
@@ -172,72 +242,38 @@ async def api_download_url(request: Request) -> Response:
172242 if err := _validate_uuid (task_id ):
173243 return err
174244
175- if err := await _validate_poll_token_bearer_only (task_id , request ):
176- return err
177-
178- download_url = f"{ settings .mcp_server_url } /api/results/{ task_id } /download"
179- return JSONResponse ({"download_url" : download_url }, headers = cors )
180-
181-
182- async def api_download (request : Request ) -> Response : # noqa: PLR0911
183- """Download task results as CSV or JSON.
184-
185- Fetches results from the public Engine API using the per-task API key
186- stored in Redis.
187-
188- Query params:
189- format: "csv" (default) or "json"
190- """
191- cors = _cors_headers ()
192- if request .method == "OPTIONS" :
193- return Response (
194- status_code = 204 ,
195- headers = {** cors , "Access-Control-Max-Age" : "3600" },
196- )
197-
198- task_id = request .path_params ["task_id" ]
199-
200- if err := _validate_uuid (task_id ):
245+ if err := await _validate_poll_token (task_id , request ):
201246 return err
202247
203248 fmt = request .query_params .get ("format" , "csv" )
204249 if fmt not in ("csv" , "json" ):
205250 return JSONResponse (
206251 {"error" : "Unsupported format" }, status_code = 400 , headers = cors
207252 )
208- # Fetch results via the public API (paginated path handles citation
253+
254+ # Fetch results via the public API (parquet-first path handles citation
209255 # resolution and internal column stripping automatically).
210256 api_key = await redis_store .get_task_token (task_id )
211257 if not api_key :
212258 return JSONResponse (
213- {"error" : "Unknown task or expired session" },
214- status_code = 404 ,
215- headers = cors ,
259+ {"error" : "Results not found or expired" }, status_code = 404 , headers = cors
216260 )
217-
218261 try :
219- # Trailing slash on base_url is required for httpx to append
220- # relative paths correctly (RFC 3986). A leading slash on the
221- # request path would *replace* the base path, sending the
222- # request to https://host/tasks/… instead of …/api/v0/tasks/….
223- base = settings .futuresearch_api_url .rstrip ("/" ) + "/"
224- async with httpx .AsyncClient (
225- base_url = base ,
226- headers = {"Authorization" : f"Bearer { api_key } " },
227- ) as http :
228- resp = await http .get (
229- f"tasks/{ task_id } /result" ,
230- params = {"offset" : 0 , "limit" : 100000 },
231- )
232- resp .raise_for_status ()
233- body = resp .json ()
234- records : list [dict ] = body .get ("data" ) or []
262+ client = AuthenticatedClient (
263+ base_url = settings .futuresearch_api_url ,
264+ token = api_key ,
265+ raise_on_unexpected_status = True ,
266+ follow_redirects = True ,
267+ )
268+ rows , _total , _session_id , _artifact_id = await _fetch_task_result (
269+ client , task_id
270+ )
271+ records : list [dict ] = _sanitize_records (rows )
235272 except Exception :
236- logger .exception ("Failed to fetch results for download, task %s" , task_id )
273+ logger .warning ("Failed to fetch results for task %s" , task_id , exc_info = True )
237274 return JSONResponse (
238- {"error" : "Failed to fetch results " }, status_code = 500 , headers = cors
275+ {"error" : "Results not found or expired " }, status_code = 404 , headers = cors
239276 )
240-
241277 safe_prefix = "" .join (c for c in task_id [:8 ] if c .isalnum () or c == "-" )
242278
243279 if fmt == "json" :
@@ -251,6 +287,7 @@ async def api_download(request: Request) -> Response: # noqa: PLR0911
251287 },
252288 )
253289
290+ # CSV generated on-the-fly from the already-resolved records.
254291 csv_text = pd .DataFrame (records ).to_csv (index = False , quoting = csv .QUOTE_ALL )
255292 return Response (
256293 content = csv_text ,
0 commit comments