diff --git a/pyflowlauncher/event.py b/pyflowlauncher/event.py index 7e8a06a..3446011 100644 --- a/pyflowlauncher/event.py +++ b/pyflowlauncher/event.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Iterable, Type, Union from .result import Result, send_results +from .response import _collect_item class EventNotFound(Exception): @@ -41,15 +42,16 @@ def get_event(self, event: str) -> Callable[..., Any]: async def _await_maybe(self, result: Any) -> Any: if asyncio.iscoroutine(result): - return await result + return await self._await_maybe(await result) if inspect.isasyncgen(result): results = [] async for item in result: - if isinstance(item, Result): - results.append(item) - elif isinstance(item, list): - results.extend(r for r in item if isinstance(r, Result)) + results.extend(_collect_item(item)) return send_results(results) + if isinstance(result, Result): + return send_results([result]) + if isinstance(result, list): + return send_results([r for r in result if isinstance(r, Result)]) return result async def trigger_exception_handler(self, exception: Exception) -> Any: diff --git a/pyflowlauncher/plugin.py b/pyflowlauncher/plugin.py index 3e75b3c..0e2b16e 100644 --- a/pyflowlauncher/plugin.py +++ b/pyflowlauncher/plugin.py @@ -30,18 +30,20 @@ def __init__(self, methods: list[Method] | None = None) -> None: def add_method(self, method: Method) -> str: """Add a method to the event handler.""" + @wraps(method) + def wrapper(*args, **kwargs): + return handle_response(method(*args, **kwargs)) + setattr(wrapper, '_is_registered_method', True) setattr(method, '_is_registered_method', True) - return self._event_handler.add_event(method) + return self._event_handler.add_event(wrapper) def add_methods(self, methods: Iterable[Method]) -> None: - self._event_handler.add_events(methods) + for method in methods: + self.add_method(method) def on_method(self, method: Method) -> Method: - @wraps(method) - def wrapper(*args, **kwargs): - return handle_response(method(*args, **kwargs)) - self.add_method(wrapper) - return wrapper + self.add_method(method) + return method def method(self, method: Method) -> Method: """Register a method to be called when the plugin is run.""" diff --git a/pyflowlauncher/response.py b/pyflowlauncher/response.py index 3287196..8625037 100644 --- a/pyflowlauncher/response.py +++ b/pyflowlauncher/response.py @@ -17,17 +17,22 @@ def handle_response(result: Any) -> Union[JsonRPCResponse, JsonRPCRequest, None] if isinstance(result, Result): return send_results([result]) if isinstance(result, list): - return send_results(result) + return send_results([r for r in result if isinstance(r, Result)]) if inspect.isgenerator(result): return _collect_generator(result) return result +def _collect_item(item: Any) -> list[Result]: + if isinstance(item, Result): + return [item] + if isinstance(item, list): + return [r for r in item if isinstance(r, Result)] + return [] + + def _collect_generator(gen: Generator) -> JsonRPCResponse: results = [] for item in gen: - if isinstance(item, Result): - results.append(item) - elif isinstance(item, list): - results.extend(r for r in item if isinstance(r, Result)) + results.extend(_collect_item(item)) return send_results(results) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 7f0ec1b..4cc8d37 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -17,13 +17,14 @@ def query(query: str): def test_add_method(): plugin = Plugin() plugin.add_method(temp_method1) - assert plugin._event_handler._events == {'temp_method1': temp_method1} + assert 'temp_method1' in plugin._event_handler._events def test_add_methods(): plugin = Plugin() plugin.add_methods([temp_method1, temp_method2]) - assert plugin._event_handler._events == {'temp_method1': temp_method1, 'temp_method2': temp_method2} + assert 'temp_method1' in plugin._event_handler._events + assert 'temp_method2' in plugin._event_handler._events def test_settings(): diff --git a/tests/test_response.py b/tests/test_response.py index a701be0..be7024d 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -137,3 +137,69 @@ async def query(q: str): result = await plugin._event_handler.trigger_event("query", "test") assert result == send_results([Result(title="x"), Result(title="y")]) + + +# --- Regression: _is_registered_method on on_method return value --- + +def test_on_method_result_has_registered_flag(): + plugin = Plugin() + + @plugin.on_method + def query(q: str): + return Result(title="x") + + assert getattr(query, '_is_registered_method', False) is True + + +def test_add_method_result_has_registered_flag(): + plugin = Plugin() + + def query(q: str): + return Result(title="x") + + plugin.add_method(query) + assert getattr(query, '_is_registered_method', False) is True + + +def test_on_method_add_action_does_not_raise(): + plugin = Plugin() + + @plugin.on_method + def action_handler(): + pass + + r = Result(title="x") + r.add_action(action_handler) # should not raise MethodNotRegisteredError + + +# --- Regression: async def returning Result/list --- + +@pytest.mark.asyncio +async def test_async_return_single_result(): + plugin = Plugin() + + @plugin.on_method + async def query(q: str): + return Result(title="async-return") + + result = await plugin._event_handler.trigger_event("query", "test") + assert result == send_results([Result(title="async-return")]) + + +@pytest.mark.asyncio +async def test_async_return_list_of_results(): + plugin = Plugin() + + @plugin.on_method + async def query(q: str): + return [Result(title="a"), Result(title="b")] + + result = await plugin._event_handler.trigger_event("query", "test") + assert result == send_results([Result(title="a"), Result(title="b")]) + + +# --- Regression: list branch filters non-Result items --- + +def test_list_with_non_result_items_filtered(): + results = [Result(title="a"), "not a result", 42] + assert handle_response(results) == send_results([Result(title="a")])