Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions pyflowlauncher/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Callable, Iterable, Type, Union

from .result import Result, send_results
from .response import _collect_item


class EventNotFound(Exception):
Expand Down Expand Up @@ -41,15 +42,16 @@ def get_event(self, event: str) -> Callable[..., Any]:

async def _await_maybe(self, result: Any) -> Any:
if asyncio.iscoroutine(result):
return await result
return await self._await_maybe(await result)
if inspect.isasyncgen(result):
results = []
async for item in result:
if isinstance(item, Result):
results.append(item)
elif isinstance(item, list):
results.extend(r for r in item if isinstance(r, Result))
results.extend(_collect_item(item))
return send_results(results)
if isinstance(result, Result):
return send_results([result])
if isinstance(result, list):
return send_results([r for r in result if isinstance(r, Result)])
return result

async def trigger_exception_handler(self, exception: Exception) -> Any:
Expand Down
16 changes: 9 additions & 7 deletions pyflowlauncher/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,20 @@ def __init__(self, methods: list[Method] | None = None) -> None:

def add_method(self, method: Method) -> str:
"""Add a method to the event handler."""
@wraps(method)
def wrapper(*args, **kwargs):
return handle_response(method(*args, **kwargs))
setattr(wrapper, '_is_registered_method', True)
setattr(method, '_is_registered_method', True)
return self._event_handler.add_event(method)
return self._event_handler.add_event(wrapper)

def add_methods(self, methods: Iterable[Method]) -> None:
self._event_handler.add_events(methods)
for method in methods:
self.add_method(method)

def on_method(self, method: Method) -> Method:
@wraps(method)
def wrapper(*args, **kwargs):
return handle_response(method(*args, **kwargs))
self.add_method(wrapper)
return wrapper
self.add_method(method)
return method

def method(self, method: Method) -> Method:
"""Register a method to be called when the plugin is run."""
Expand Down
15 changes: 10 additions & 5 deletions pyflowlauncher/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,22 @@ def handle_response(result: Any) -> Union[JsonRPCResponse, JsonRPCRequest, None]
if isinstance(result, Result):
return send_results([result])
if isinstance(result, list):
return send_results(result)
return send_results([r for r in result if isinstance(r, Result)])
if inspect.isgenerator(result):
return _collect_generator(result)
return result


def _collect_item(item: Any) -> list[Result]:
if isinstance(item, Result):
return [item]
if isinstance(item, list):
return [r for r in item if isinstance(r, Result)]
return []


def _collect_generator(gen: Generator) -> JsonRPCResponse:
results = []
for item in gen:
if isinstance(item, Result):
results.append(item)
elif isinstance(item, list):
results.extend(r for r in item if isinstance(r, Result))
results.extend(_collect_item(item))
return send_results(results)
5 changes: 3 additions & 2 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ def query(query: str):
def test_add_method():
plugin = Plugin()
plugin.add_method(temp_method1)
assert plugin._event_handler._events == {'temp_method1': temp_method1}
assert 'temp_method1' in plugin._event_handler._events


def test_add_methods():
plugin = Plugin()
plugin.add_methods([temp_method1, temp_method2])
assert plugin._event_handler._events == {'temp_method1': temp_method1, 'temp_method2': temp_method2}
assert 'temp_method1' in plugin._event_handler._events
assert 'temp_method2' in plugin._event_handler._events


def test_settings():
Expand Down
66 changes: 66 additions & 0 deletions tests/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,69 @@ async def query(q: str):

result = await plugin._event_handler.trigger_event("query", "test")
assert result == send_results([Result(title="x"), Result(title="y")])


# --- Regression: _is_registered_method on on_method return value ---

def test_on_method_result_has_registered_flag():
plugin = Plugin()

@plugin.on_method
def query(q: str):
return Result(title="x")

assert getattr(query, '_is_registered_method', False) is True


def test_add_method_result_has_registered_flag():
plugin = Plugin()

def query(q: str):
return Result(title="x")

plugin.add_method(query)
assert getattr(query, '_is_registered_method', False) is True


def test_on_method_add_action_does_not_raise():
plugin = Plugin()

@plugin.on_method
def action_handler():
pass

r = Result(title="x")
r.add_action(action_handler) # should not raise MethodNotRegisteredError


# --- Regression: async def returning Result/list ---

@pytest.mark.asyncio
async def test_async_return_single_result():
plugin = Plugin()

@plugin.on_method
async def query(q: str):
return Result(title="async-return")

result = await plugin._event_handler.trigger_event("query", "test")
assert result == send_results([Result(title="async-return")])


@pytest.mark.asyncio
async def test_async_return_list_of_results():
plugin = Plugin()

@plugin.on_method
async def query(q: str):
return [Result(title="a"), Result(title="b")]

result = await plugin._event_handler.trigger_event("query", "test")
assert result == send_results([Result(title="a"), Result(title="b")])


# --- Regression: list branch filters non-Result items ---

def test_list_with_non_result_items_filtered():
results = [Result(title="a"), "not a result", 42]
assert handle_response(results) == send_results([Result(title="a")])
Loading