Skip to content

Commit 3f23013

Browse files
rgambee authored and github-actions[bot] committed
API endpoint for task cost (#5377)
This adds a new API endpoint for fetching the cost of a particular task: `tasks/<task-id>/cost`. The feature is wired through the SDK to the MCP server. The CC system prompt encourages the agent to use this tool instead of estimating costs. I made some other adjustments to the system prompt to disambiguate the user's account balance from the conversation budget. Sourced from commit 58494fdf34b17e2e639777ab6c34d701c4f69040
1 parent da8f557 commit 3f23013

14 files changed

Lines changed: 405 additions & 7 deletions

futuresearch-mcp/manifest.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@
9292
"name": "futuresearch_balance",
9393
"description": "Check the current billing balance for the authenticated user."
9494
},
95+
{
96+
"name": "futuresearch_task_cost",
97+
"description": "Get the billed cost of a completed task."
98+
},
9599
{
96100
"name": "futuresearch_browse_lists",
97101
"description": "Browse available reference lists of well-known entities."

futuresearch-mcp/src/futuresearch_mcp/models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,19 @@ def validate_task_id(cls, v: str) -> str:
637637
return _validate_task_id(v)
638638

639639

640+
class TaskCostInput(BaseModel):
    """Arguments accepted by the task-cost tool: a single task identifier."""

    # Surrounding whitespace is trimmed from string fields and unknown keys are rejected.
    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")

    task_id: str = Field(..., description="The task ID to check the cost for.")

    @field_validator("task_id")
    @classmethod
    def validate_task_id(cls, value: str) -> str:
        # Delegate to the module-level task-ID validator.
        return _validate_task_id(value)
651+
652+
640653
def _validate_output_path(v: str | None) -> str | None:
641654
"""Validate output_path ends in .csv and parent directory exists."""
642655
if v is not None:

futuresearch-mcp/src/futuresearch_mcp/tools.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from futuresearch.constants import FuturesearchError as EveryrowError
1616
from futuresearch.generated.api.billing import get_billing_balance_billing_get
1717
from futuresearch.generated.api.tasks import get_task_status_tasks_task_id_status_get
18+
from futuresearch.generated.models.task_cost_status import TaskCostStatus
1819
from futuresearch.generated.models.task_status import TaskStatus
1920
from futuresearch.ops import (
2021
_submit_agent_map,
@@ -27,7 +28,7 @@
2728
merge_async,
2829
)
2930
from futuresearch.session import list_sessions
30-
from futuresearch.task import cancel_task
31+
from futuresearch.task import cancel_task, get_task_cost
3132
from mcp.types import CallToolResult, TextContent, ToolAnnotations
3233
from pydantic import BaseModel, create_model
3334

@@ -49,6 +50,7 @@
4950
RankInput,
5051
SingleAgentInput,
5152
StdioResultsInput,
53+
TaskCostInput,
5254
UploadDataInput,
5355
UseListInput,
5456
)
@@ -1337,6 +1339,58 @@ async def futuresearch_balance(ctx: FuturesearchContext) -> list[TextContent]:
13371339
]
13381340

13391341

1342+
@mcp.tool(
    name="futuresearch_task_cost",
    structured_output=False,
    annotations=ToolAnnotations(
        title="Get Task Cost",
        readOnlyHint=True,
        destructiveHint=False,
        idempotentHint=True,
        openWorldHint=False,
    ),
)
async def futuresearch_task_cost(
    params: TaskCostInput, ctx: FuturesearchContext
) -> list[TextContent]:
    """Get the billed cost of a completed task.

    Returns the amount charged in dollars. There is a delay between task completion and
    cost calculation. Returns 'pending' if the cost hasn't settled yet.
    """
    # NOTE(review): the docstring above is left verbatim in case the MCP framework
    # surfaces it to the agent as the tool description -- confirm before rewording.
    logger.info("futuresearch_task_cost: task_id=%s", params.task_id)
    client = _get_client(ctx)

    try:
        response = await get_task_cost(
            task_id=UUID(params.task_id),
            client=client,
        )
    except Exception:
        # Tool boundary: log the full traceback and hand the agent a plain message.
        logger.exception("Failed to get task cost for task %s", params.task_id)
        return [
            TextContent(
                type="text",
                text=f"Error retrieving cost for task {params.task_id}. Please try again.",
            )
        ]

    if response.status == TaskCostStatus.PENDING:
        return [
            TextContent(
                type="text",
                text=f"Cost for task {params.task_id} is still being calculated. Try again in ~30 seconds.",
            )
        ]

    # A non-pending response is expected to carry a numeric cost, but formatting a
    # missing value (e.g. None) with ".2f" raises TypeError. Guard it so the tool
    # returns a message instead of propagating an unhandled exception.
    # NOTE(review): confirm against TaskCostResponse whether cost_dollars is optional.
    cost = response.cost_dollars
    try:
        formatted_cost = f"{cost:.2f}"
    except (TypeError, ValueError):
        logger.warning(
            "Task %s has status %s but no usable cost value: %r",
            params.task_id,
            response.status,
            cost,
        )
        return [
            TextContent(
                type="text",
                text=f"Cost for task {params.task_id} is not available (status: {response.status}).",
            )
        ]

    return [
        TextContent(
            type="text",
            text=f"Task {params.task_id} cost: ${formatted_cost}",
        )
    ]
1392+
1393+
13401394
@mcp.tool(
13411395
name="futuresearch_list_session_tasks",
13421396
structured_output=False,

futuresearch-mcp/tests/test_mcp_e2e.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,10 @@ async def test_list_tools(self, _http_state):
168168
"futuresearch_merge",
169169
"futuresearch_progress",
170170
"futuresearch_rank",
171-
"futuresearch_status",
172171
"futuresearch_results",
173172
"futuresearch_single_agent",
173+
"futuresearch_status",
174+
"futuresearch_task_cost",
174175
"futuresearch_upload_data",
175176
"futuresearch_use_list",
176177
]
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
from http import HTTPStatus
2+
from typing import Any
3+
from urllib.parse import quote
4+
from uuid import UUID
5+
6+
import httpx
7+
8+
from ... import errors
9+
from ...client import AuthenticatedClient, Client
10+
from ...models.error_response import ErrorResponse
11+
from ...models.http_validation_error import HTTPValidationError
12+
from ...models.task_cost_response import TaskCostResponse
13+
from ...types import Response
14+
15+
16+
def _get_kwargs(
17+
task_id: UUID,
18+
) -> dict[str, Any]:
19+
_kwargs: dict[str, Any] = {
20+
"method": "get",
21+
"url": "/tasks/{task_id}/cost".format(
22+
task_id=quote(str(task_id), safe=""),
23+
),
24+
}
25+
26+
return _kwargs
27+
28+
29+
def _parse_response(
    *, client: AuthenticatedClient | Client, response: httpx.Response
) -> ErrorResponse | HTTPValidationError | TaskCostResponse | None:
    """Convert the HTTP response body into the model matching its status code.

    Args:
        client: Client whose ``raise_on_unexpected_status`` flag decides how an
            unhandled status code is treated.
        response: The httpx response to decode.

    Raises:
        errors.UnexpectedStatus: If the status code is not 200, 404 or 422 and
            ``client.raise_on_unexpected_status`` is true.

    Returns:
        The parsed model for 200/404/422, or None for any other status when the
        client is not configured to raise.
    """
    # Status code -> model class used to decode the JSON payload.
    model_by_status = {
        200: TaskCostResponse,
        404: ErrorResponse,
        422: HTTPValidationError,
    }

    model = model_by_status.get(response.status_code)
    if model is not None:
        return model.from_dict(response.json())

    if not client.raise_on_unexpected_status:
        return None
    raise errors.UnexpectedStatus(response.status_code, response.content)
51+
52+
53+
def _build_response(
    *, client: AuthenticatedClient | Client, response: httpx.Response
) -> Response[ErrorResponse | HTTPValidationError | TaskCostResponse]:
    """Wrap an httpx response in a ``Response`` carrying the parsed payload.

    Args:
        client: Client forwarded to ``_parse_response``.
        response: The raw httpx response.

    Returns:
        Response holding the status code, raw content, headers and parsed body.
    """
    # Field order mirrors the constructor call: the status conversion runs before
    # the body is parsed, so a failure there surfaces first.
    status = HTTPStatus(response.status_code)
    raw_content = response.content
    response_headers = response.headers
    parsed_body = _parse_response(client=client, response=response)
    return Response(
        status_code=status,
        content=raw_content,
        headers=response_headers,
        parsed=parsed_body,
    )
62+
63+
64+
def sync_detailed(
    task_id: UUID,
    *,
    client: AuthenticatedClient,
) -> Response[ErrorResponse | HTTPValidationError | TaskCostResponse]:
    """Get task cost

     Get the billed cost of a task. Returns status 'pending' until the cost has been calculated, then
    'settled' with the final cost in dollars.

    Blocking variant that returns the full response wrapper.

    Args:
        task_id (UUID):

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[ErrorResponse | HTTPValidationError | TaskCostResponse]
    """
    request_args = _get_kwargs(task_id=task_id)
    http_response = client.get_httpx_client().request(**request_args)
    return _build_response(client=client, response=http_response)
94+
95+
96+
def sync(
    task_id: UUID,
    *,
    client: AuthenticatedClient,
) -> ErrorResponse | HTTPValidationError | TaskCostResponse | None:
    """Get task cost

     Get the billed cost of a task. Returns status 'pending' until the cost has been calculated, then
    'settled' with the final cost in dollars.

    Blocking variant that returns only the parsed payload of ``sync_detailed``.

    Args:
        task_id (UUID):

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        ErrorResponse | HTTPValidationError | TaskCostResponse
    """
    detailed = sync_detailed(task_id=task_id, client=client)
    return detailed.parsed
121+
122+
123+
async def asyncio_detailed(
    task_id: UUID,
    *,
    client: AuthenticatedClient,
) -> Response[ErrorResponse | HTTPValidationError | TaskCostResponse]:
    """Get task cost

     Get the billed cost of a task. Returns status 'pending' until the cost has been calculated, then
    'settled' with the final cost in dollars.

    Awaitable variant that returns the full response wrapper.

    Args:
        task_id (UUID):

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[ErrorResponse | HTTPValidationError | TaskCostResponse]
    """
    request_args = _get_kwargs(task_id=task_id)
    http_response = await client.get_async_httpx_client().request(**request_args)
    return _build_response(client=client, response=http_response)
151+
152+
153+
async def asyncio(
    task_id: UUID,
    *,
    client: AuthenticatedClient,
) -> ErrorResponse | HTTPValidationError | TaskCostResponse | None:
    """Get task cost

     Get the billed cost of a task. Returns status 'pending' until the cost has been calculated, then
    'settled' with the final cost in dollars.

    Awaitable variant that returns only the parsed payload of ``asyncio_detailed``.

    Args:
        task_id (UUID):

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        ErrorResponse | HTTPValidationError | TaskCostResponse
    """
    detailed = await asyncio_detailed(task_id=task_id, client=client)
    return detailed.parsed

src/futuresearch/generated/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
from .subscription_info import SubscriptionInfo
6464
from .subscription_status_response import SubscriptionStatusResponse
6565
from .subscription_tier import SubscriptionTier
66+
from .task_cost_response import TaskCostResponse
67+
from .task_cost_status import TaskCostStatus
6668
from .task_progress_info import TaskProgressInfo
6769
from .task_result_response import TaskResultResponse
6870
from .task_result_response_data_type_0_item import TaskResultResponseDataType0Item
@@ -146,6 +148,8 @@
146148
"SubscriptionInfo",
147149
"SubscriptionStatusResponse",
148150
"SubscriptionTier",
151+
"TaskCostResponse",
152+
"TaskCostStatus",
149153
"TaskProgressInfo",
150154
"TaskResultResponse",
151155
"TaskResultResponseDataType0Item",

src/futuresearch/generated/models/forecast_operation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class ForecastOperation:
2929
question/scenario to forecast.
3030
forecast_type (ForecastOperationForecastType): Type of forecast. 'binary': yes/no probability (0-100) for
3131
questions like 'Will X happen?'. 'numeric': percentile estimates (p10-p90) for questions like 'What will the
32-
price/value/count be?'. Requires output_field when 'numeric'.
32+
price/value/count be?'. 'date': date percentile estimates (p10-p90) as YYYY-MM-DD strings for timing questions
33+
like 'When will X happen?'. Requires output_field when 'numeric' or 'date'.
3334
session_id (None | Unset | UUID): Session ID. If not provided, a new session is auto-created for this task.
3435
webhook_url (None | str | Unset): Optional URL to receive a POST callback when the task completes or fails.
3536
output_field (None | str | Unset): Name of the numeric quantity being forecast (e.g. 'price', 'count'). Required

src/futuresearch/generated/models/forecast_operation_forecast_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
class ForecastOperationForecastType(str, Enum):
55
BINARY = "binary"
6-
NUMERIC = "numeric"
76
DATE = "date"
7+
NUMERIC = "numeric"
88

99
def __str__(self) -> str:
1010
return str(self.value)

src/futuresearch/generated/models/progress_summary_entry.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@ class ProgressSummaryEntry:
2121
iteration_number (int): Iteration within the trace
2222
summary (str): LLM-generated progress summary text
2323
updated_at (str): When this summary was created
24-
row_index (int | None | Unset): Input row index this trace covers (if resolved)
24+
row_indices (list[int] | Unset): Input row indices this trace covers
25+
row_index (int | None | Unset): Deprecated: use row_indices instead
2526
"""
2627

2728
trace_id: UUID
2829
task_id: UUID
2930
iteration_number: int
3031
summary: str
3132
updated_at: str
33+
row_indices: list[int] | Unset = UNSET
3234
row_index: int | None | Unset = UNSET
3335
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
3436

@@ -43,6 +45,10 @@ def to_dict(self) -> dict[str, Any]:
4345

4446
updated_at = self.updated_at
4547

48+
row_indices: list[int] | Unset = UNSET
49+
if not isinstance(self.row_indices, Unset):
50+
row_indices = self.row_indices
51+
4652
row_index: int | None | Unset
4753
if isinstance(self.row_index, Unset):
4854
row_index = UNSET
@@ -60,6 +66,8 @@ def to_dict(self) -> dict[str, Any]:
6066
"updated_at": updated_at,
6167
}
6268
)
69+
if row_indices is not UNSET:
70+
field_dict["row_indices"] = row_indices
6371
if row_index is not UNSET:
6472
field_dict["row_index"] = row_index
6573

@@ -78,6 +86,8 @@ def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
7886

7987
updated_at = d.pop("updated_at")
8088

89+
row_indices = cast(list[int], d.pop("row_indices", UNSET))
90+
8191
def _parse_row_index(data: object) -> int | None | Unset:
8292
if data is None:
8393
return data
@@ -93,6 +103,7 @@ def _parse_row_index(data: object) -> int | None | Unset:
93103
iteration_number=iteration_number,
94104
summary=summary,
95105
updated_at=updated_at,
106+
row_indices=row_indices,
96107
row_index=row_index,
97108
)
98109

0 commit comments

Comments
 (0)