Skip to content

Commit c7e3d47

Browse files
committed
Ensure ActionDescriptor.output_model is always a model.
I've used `wrap_plain_types_in_rootmodel` to encapsulate action return types in a RootModel if they are not already BaseModel subclasses. This stops annotations from being lost when actions are returned. I needed to slightly adjust a test that checked action output models, I think that change is uncontroversial.
1 parent 4cbc085 commit c7e3d47

2 files changed

Lines changed: 6 additions & 8 deletions

File tree

src/labthings_fastapi/actions.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
from .base_descriptor import BaseDescriptor
4343
from .logs import add_thing_log_destination
44-
from .utilities import model_to_dict
44+
from .utilities import model_to_dict, wrap_plain_types_in_rootmodel
4545
from .invocations import InvocationModel, InvocationStatus, LogRecordModel
4646
from .dependencies.invocation import NonWarningInvocationID
4747
from .exceptions import (
@@ -477,7 +477,6 @@ def list_all_invocations(
477477

478478
@app.get(
479479
ACTION_INVOCATIONS_PATH + "/{id}",
480-
response_model=InvocationModel,
481480
responses={404: {"description": "Invocation ID not found"}},
482481
)
483482
def action_invocation(
@@ -683,7 +682,7 @@ def __init__(
683682
remove_first_positional_arg=True,
684683
ignore=[p.name for p in self.dependency_params],
685684
)
686-
self.output_model = return_type(func)
685+
self.output_model = wrap_plain_types_in_rootmodel(return_type(func))
687686
self.invocation_model = create_model(
688687
f"{name}_invocation",
689688
__base__=InvocationModel,

tests/test_actions.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def action_wrapper(*args, **kwargs):
161161
return action_wrapper
162162

163163

164-
def assert_input_models_equivalent(model_a, model_b):
164+
def assert_models_equivalent(model_a, model_b):
165165
"""Check two basemodels are equivalent."""
166166
keys = list(model_a.model_fields.keys())
167167
assert list(model_b.model_fields.keys()) == keys
@@ -198,11 +198,10 @@ def decorated(
198198
"""An example decorated action with type annotations."""
199199
return 0.5
200200

201-
assert_input_models_equivalent(
202-
Example.action.input_model, Example.decorated.input_model
201+
assert_models_equivalent(Example.action.input_model, Example.decorated.input_model)
202+
assert_models_equivalent(
203+
Example.action.output_model, Example.decorated.output_model
203204
)
204-
assert Example.action.output_model == Example.decorated.output_model
205-
206205
# Check we can make the thing and it has a valid TD
207206
example = create_thing_without_server(Example)
208207
example.validate_thing_description()

0 commit comments

Comments
 (0)