55from ellar .common .interfaces import IExecutionContext
66from ellar .common .logger import logger
77from pydantic import BaseModel , create_model
8+ from pydantic .fields import ModelField
89from starlette .convertors import Convertor
910
1011from .. import params
11- from ..resolvers import BaseRouteParameterResolver
12+ from ..resolvers import (
13+ BulkFormParameterResolver ,
14+ IRouteParameterResolver ,
15+ RouteParameterModelField ,
16+ )
1217from .base import EndpointArgsModel
1318from .extra_args import ExtraEndpointArg
1419
@@ -78,6 +83,16 @@ def __init__(
7883 )
7984 self .operation_unique_id = operation_unique_id
8085
86+ def _get_body_resolver_model_fields (
87+ self , body_resolvers : t .List [IRouteParameterResolver ]
88+ ) -> t .Generator [t .Union ["RouteParameterModelField" , ModelField ], t .Any , None ]:
89+ for resolver in body_resolvers :
90+ if isinstance (resolver , BulkFormParameterResolver ):
91+ for form_resolver in resolver .resolvers :
92+ yield form_resolver .model_field
93+ else :
94+ yield resolver .model_field
95+
8196 def build_body_field (self ) -> None :
8297 """
8398 Group common body / form fields to one field
@@ -91,23 +106,17 @@ def build_body_field(self) -> None:
91106 and len (body_resolvers ) == 1
92107 and not (
93108 body_resolvers [0 ].model_field .field_info .embed # type: ignore[attr-defined]
94- and isinstance (
95- body_resolvers [0 ].model_field .field_info , params .BodyFieldInfo # type: ignore[attr-defined]
96- )
97109 )
98110 ):
99- check_file_field (body_resolvers [0 ].model_field ) # type: ignore[attr-defined]
111+ check_file_field (body_resolvers [0 ].model_field )
100112 self .body_resolver = body_resolvers [0 ]
101113 elif body_resolvers :
102114 # if body_resolvers is more than one, we create a bulk_body_resolver instead
103- _body_resolvers_model_fields = (
104- t .cast (BaseRouteParameterResolver , item ).model_field
105- for item in body_resolvers
106- )
107115 model_name = "body_" + self .operation_unique_id
108116 body_model_field : t .Type [BaseModel ] = create_model (model_name )
109117 _fields_required , _body_param_class = [], {}
110- for f in _body_resolvers_model_fields :
118+
119+ for f in self ._get_body_resolver_model_fields (body_resolvers ):
111120 f .field_info .embed = True # type:ignore[attr-defined]
112121 body_model_field .__fields__ [f .name ] = f
113122 _fields_required .append (f .required )
@@ -122,12 +131,12 @@ def build_body_field(self) -> None:
122131 media_type = "application/json"
123132 if len (_body_param_class ) == 1 :
124133 _ , (klass , field_info ) = _body_param_class .popitem ()
125- body_field_info = klass
134+ body_field_info = klass # type:ignore[assignment]
126135 media_type = getattr (field_info , "media_type" , media_type )
127136 elif len (_body_param_class ) > 1 :
128137 key = sorted (_body_param_class .keys (), reverse = True )[0 ]
129138 klass , field_info = _body_param_class [key ]
130- body_field_info = klass
139+ body_field_info = klass # type:ignore[assignment]
131140 media_type = getattr (field_info , "media_type" , media_type )
132141
133142 final_field = create_model_field (
0 commit comments