|
1 | 1 | import logging |
2 | 2 | import traceback |
3 | | -from typing import Any |
| 3 | +from typing import Any, Generic, TypeVar |
4 | 4 |
|
5 | 5 | from fastapi import HTTPException, Request, status |
6 | 6 | from fastapi.datastructures import URL |
|
50 | 50 |
|
51 | 51 | logger = logging.getLogger(__name__) |
52 | 52 |
|
| 53 | +T = TypeVar("T", bound=OrderStatus) |
53 | 54 |
|
54 | | -class RootRouter(StapiFastapiBaseRouter): |
| 55 | + |
| 56 | +class RootRouter(StapiFastapiBaseRouter, Generic[T]): |
55 | 57 | def __init__( |
56 | 58 | self, |
57 | | - get_orders: GetOrders, |
| 59 | + get_orders: GetOrders[T], |
58 | 60 | get_order: GetOrder, |
59 | 61 | get_order_statuses: GetOrderStatuses | None = None, # type: ignore |
60 | 62 | get_opportunity_search_records: GetOpportunitySearchRecords | None = None, |
@@ -240,7 +242,7 @@ def get_products(self, request: Request, next: str | None = None, limit: int = 1 |
240 | 242 |
|
241 | 243 | async def get_orders( # noqa: C901 |
242 | 244 | self, request: Request, next: str | None = None, limit: int = 10 |
243 | | - ) -> OrderCollection[OrderStatus]: |
| 245 | + ) -> OrderCollection[T]: |
244 | 246 | links: list[Link] = [] |
245 | 247 | orders_count: int | None = None |
246 | 248 | match await self._get_orders(next, limit, request): |
@@ -271,13 +273,13 @@ async def get_orders( # noqa: C901 |
271 | 273 | case _: |
272 | 274 | raise AssertionError("Expected code to be unreachable") |
273 | 275 |
|
274 | | - return OrderCollection( |
| 276 | + return OrderCollection[T]( |
275 | 277 | features=orders, |
276 | 278 | links=links, |
277 | 279 | number_matched=orders_count, |
278 | 280 | ) |
279 | 281 |
|
280 | | - async def get_order(self, order_id: str, request: Request) -> Order[OrderStatus]: |
| 282 | + async def get_order(self, order_id: str, request: Request) -> Order[T]: |
281 | 283 | """ |
282 | 284 | Get details for order with `order_id`. |
283 | 285 | """ |
@@ -306,7 +308,7 @@ async def get_order_statuses( |
306 | 308 | request: Request, |
307 | 309 | next: str | None = None, |
308 | 310 | limit: int = 10, |
309 | | - ) -> OrderStatuses: # type: ignore |
| 311 | + ) -> OrderStatuses[T]: |
310 | 312 | links: list[Link] = [] |
311 | 313 | match await self._get_order_statuses(order_id, next, limit, request): |
312 | 314 | case Success(Some((statuses, maybe_pagination_token))): |
@@ -350,7 +352,7 @@ def generate_order_href(self, request: Request, order_id: str) -> URL: |
350 | 352 | def generate_order_statuses_href(self, request: Request, order_id: str) -> URL: |
351 | 353 | return self.url_for(request, f"{self.name}:{LIST_ORDER_STATUSES}", order_id=order_id) |
352 | 354 |
|
353 | | - def order_links(self, order: Order[OrderStatus], request: Request) -> list[Link]: |
| 355 | + def order_links(self, order: Order[T], request: Request) -> list[Link]: |
354 | 356 | return [ |
355 | 357 | Link( |
356 | 358 | href=self.generate_order_href(request, order.id), |
@@ -464,7 +466,7 @@ def opportunity_search_record_self_link( |
464 | 466 | return json_link("self", self.generate_opportunity_search_record_href(request, opportunity_search_record.id)) |
465 | 467 |
|
466 | 468 | @property |
467 | | - def _get_order_statuses(self) -> GetOrderStatuses: # type: ignore |
| 469 | + def _get_order_statuses(self) -> GetOrderStatuses[T]: |
468 | 470 | if not self.__get_order_statuses: |
469 | 471 | raise AttributeError("Root router does not support order status history") |
470 | 472 | return self.__get_order_statuses |
|
0 commit comments