|
2 | 2 |
|
3 | 3 | from collections.abc import Awaitable, Mapping |
4 | 4 | from functools import wraps |
5 | | -from inspect import isasyncgen, isasyncgenfunction, isgenerator |
| 5 | +from inspect import isasyncgen, isgenerator |
6 | 6 | from typing import Any, Callable, ParamSpec |
7 | 7 |
|
8 | 8 | from quart import Response, copy_current_request_context, request, stream_with_context |
@@ -51,9 +51,23 @@ def datastar_response( |
51 | 51 |
|
52 | 52 | @wraps(func) |
53 | 53 | async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: |
54 | | - if isasyncgenfunction(func): |
55 | | - return DatastarResponse(stream_with_context(func)(*args, **kwargs)) |
56 | | - return DatastarResponse(await copy_current_request_context(func)(*args, **kwargs)) |
| 54 | + # Preserve request context for whatever we return |
| 55 | + bound_func = copy_current_request_context(func) |
| 56 | + r = bound_func(*args, **kwargs) |
| 57 | + |
| 58 | + if hasattr(r, "__aiter__"): |
| 59 | + return DatastarResponse(stream_with_context(r)) |
| 60 | + |
| 61 | + if hasattr(r, "__iter__") and not isinstance(r, (str, bytes)): |
| 62 | + return DatastarResponse(stream_with_context(r)) |
| 63 | + |
| 64 | + if isinstance(r, Awaitable): |
| 65 | + async def await_and_yield(): |
| 66 | + yield await r |
| 67 | + |
| 68 | + return DatastarResponse(stream_with_context(await_and_yield())) |
| 69 | + |
| 70 | + return DatastarResponse(r) |
57 | 71 |
|
58 | 72 | wrapper.__annotations__["return"] = DatastarResponse |
59 | 73 | return wrapper |
|
0 commit comments