Skip to content

Commit bb8f532

Browse files
authored
Merge pull request #360 from kennyhei/master
Improve return types for ModelService CRUD methods
2 parents de01355 + 80d05f0 commit bb8f532

3 files changed

Lines changed: 38 additions & 30 deletions

File tree

ninja_extra/controllers/model/endpoints.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ async def delete_item(self: "ModelControllerBase", **kwargs: t.Any) -> t.Any:
752752
res = (
753753
custom_handler(self, instance=obj, **kwargs)
754754
if custom_handler
755-
else self.service.delete_async(instance=obj, **kwargs) # type:ignore[arg-type]
755+
else self.service.delete_async(instance=obj, **kwargs)
756756
)
757757

758758
await _check_if_coroutine(res)
@@ -815,7 +815,7 @@ async def patch_item(
815815
instance = (
816816
custom_handler(self, instance=obj, schema=data, **kwargs)
817817
if custom_handler
818-
else self.service.patch_async(instance=obj, schema=data, **kwargs) # type:ignore[arg-type]
818+
else self.service.patch_async(instance=obj, schema=data, **kwargs)
819819
)
820820
instance = await _check_if_coroutine(instance)
821821

@@ -856,7 +856,7 @@ async def update_item(
856856
instance = (
857857
custom_handler(self, instance=obj, schema=data, **kwargs)
858858
if custom_handler
859-
else self.service.update_async(instance=obj, schema=data, **kwargs) # type:ignore[arg-type]
859+
else self.service.update_async(instance=obj, schema=data, **kwargs)
860860
)
861861
instance = await _check_if_coroutine(instance)
862862

ninja_extra/controllers/model/interfaces.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
from ninja_extra.exceptions import APIException, NotFound
99

10+
ModelType = t.TypeVar("ModelType", bound=DjangoModel)
1011

11-
class AsyncModelServiceBase(ABC):
12+
13+
class AsyncModelServiceBase(ABC, t.Generic[ModelType]):
1214
@abstractmethod
1315
async def get_one_async(
1416
self,
@@ -30,22 +32,22 @@ async def create_async(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
3032

3133
@abstractmethod
3234
async def update_async(
33-
self, instance: DjangoModel, schema: PydanticModel, **kwargs: t.Any
35+
self, instance: ModelType, schema: PydanticModel, **kwargs: t.Any
3436
) -> t.Any:
3537
pass
3638

3739
@abstractmethod
3840
async def patch_async(
39-
self, instance: DjangoModel, schema: PydanticModel, **kwargs: t.Any
41+
self, instance: ModelType, schema: PydanticModel, **kwargs: t.Any
4042
) -> t.Any:
4143
pass
4244

4345
@abstractmethod
44-
async def delete_async(self, instance: DjangoModel, **kwargs: t.Any) -> t.Any:
46+
async def delete_async(self, instance: ModelType, **kwargs: t.Any) -> t.Any:
4547
pass
4648

4749

48-
class ModelServiceBase(ABC):
50+
class ModelServiceBase(ABC, t.Generic[ModelType]):
4951
"""
5052
Abstract service that handles Model Controller model CRUD operations
5153
"""
@@ -71,16 +73,16 @@ def create(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
7173

7274
@abstractmethod
7375
def update(
74-
self, instance: DjangoModel, schema: PydanticModel, **kwargs: t.Any
76+
self, instance: ModelType, schema: PydanticModel, **kwargs: t.Any
7577
) -> t.Any:
7678
pass
7779

7880
@abstractmethod
7981
def patch(
80-
self, instance: DjangoModel, schema: PydanticModel, **kwargs: t.Any
82+
self, instance: ModelType, schema: PydanticModel, **kwargs: t.Any
8183
) -> t.Any:
8284
pass
8385

8486
@abstractmethod
85-
def delete(self, instance: DjangoModel, **kwargs: t.Any) -> t.Any:
87+
def delete(self, instance: ModelType, **kwargs: t.Any) -> t.Any:
8688
pass

ninja_extra/controllers/model/service.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,25 @@
22
import typing as t
33

44
from asgiref.sync import sync_to_async
5-
from django.db.models import Model, QuerySet
5+
from django.db.models import QuerySet
66
from pydantic import BaseModel as PydanticModel
77

88
from ninja_extra.exceptions import APIException, NotFound
99
from ninja_extra.shortcuts import get_object_or_exception
1010

11-
from .interfaces import AsyncModelServiceBase, ModelServiceBase
11+
from .interfaces import AsyncModelServiceBase, ModelServiceBase, ModelType
1212

1313

14-
class ModelService(ModelServiceBase, AsyncModelServiceBase):
14+
class ModelService(
15+
ModelServiceBase[ModelType], AsyncModelServiceBase[ModelType], t.Generic[ModelType]
16+
):
1517
"""
1618
Model Service for Model Controller model CRUD operations with a simple logic for simple models.
1719
1820
Its advised you override this class if you have a complex model.
1921
"""
2022

21-
def __init__(self, model: t.Type[Model]) -> None:
23+
def __init__(self, model: t.Type[ModelType]) -> None:
2224
self.model = model
2325

2426
def get_one(
@@ -28,14 +30,14 @@ def get_one(
2830
error_message: t.Optional[str] = None,
2931
exception: t.Type[APIException] = NotFound,
3032
**kwargs: t.Any,
31-
) -> t.Any:
33+
) -> ModelType:
3234
obj = get_object_or_exception(
3335
klass=self.model if queryset is None else queryset,
3436
error_message=error_message,
3537
exception=exception,
3638
pk=pk,
3739
)
38-
return obj
40+
return t.cast(ModelType, obj)
3941

4042
async def get_one_async(
4143
self,
@@ -44,7 +46,7 @@ async def get_one_async(
4446
error_message: t.Optional[str] = None,
4547
exception: t.Type[APIException] = NotFound,
4648
**kwargs: t.Any,
47-
) -> t.Any:
49+
) -> ModelType:
4850
return await sync_to_async(self.get_one, thread_sensitive=True)(
4951
pk,
5052
queryset=queryset,
@@ -53,13 +55,13 @@ async def get_one_async(
5355
**kwargs,
5456
)
5557

56-
def get_all(self, **kwargs: t.Any) -> t.Union[QuerySet, t.List[t.Any]]:
58+
def get_all(self, **kwargs: t.Any) -> QuerySet[ModelType]:
5759
return self.model.objects.all()
5860

59-
async def get_all_async(self, **kwargs: t.Any) -> t.Union[QuerySet, t.List[t.Any]]:
61+
async def get_all_async(self, **kwargs: t.Any) -> QuerySet[ModelType]:
6062
return await sync_to_async(self.get_all, thread_sensitive=True)(**kwargs)
6163

62-
def create(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
64+
def create(self, schema: PydanticModel, **kwargs: t.Any) -> ModelType:
6365
data = schema.model_dump(by_alias=True)
6466
data.update(kwargs)
6567

@@ -86,10 +88,12 @@ def create(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
8688
)
8789
raise TypeError(msg) from tex
8890

89-
async def create_async(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
91+
async def create_async(self, schema: PydanticModel, **kwargs: t.Any) -> ModelType:
9092
return await sync_to_async(self.create, thread_sensitive=True)(schema, **kwargs)
9193

92-
def update(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
94+
def update(
95+
self, instance: ModelType, schema: PydanticModel, **kwargs: t.Any
96+
) -> ModelType:
9397
data = schema.model_dump(exclude_unset=True)
9498
data.update(kwargs)
9599
for attr, value in data.items():
@@ -98,22 +102,24 @@ def update(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.A
98102
return instance
99103

100104
async def update_async(
101-
self, instance: Model, schema: PydanticModel, **kwargs: t.Any
102-
) -> t.Any:
105+
self, instance: ModelType, schema: PydanticModel, **kwargs: t.Any
106+
) -> ModelType:
103107
return await sync_to_async(self.update, thread_sensitive=True)(
104108
instance, schema, **kwargs
105109
)
106110

107-
def patch(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
111+
def patch(
112+
self, instance: ModelType, schema: PydanticModel, **kwargs: t.Any
113+
) -> ModelType:
108114
return self.update(instance=instance, schema=schema, **kwargs)
109115

110116
async def patch_async(
111-
self, instance: Model, schema: PydanticModel, **kwargs: t.Any
112-
) -> t.Any:
117+
self, instance: ModelType, schema: PydanticModel, **kwargs: t.Any
118+
) -> ModelType:
113119
return await self.update_async(instance=instance, schema=schema, **kwargs)
114120

115-
def delete(self, instance: Model, **kwargs: t.Any) -> t.Any:
121+
def delete(self, instance: ModelType, **kwargs: t.Any) -> t.Any:
116122
instance.delete()
117123

118-
async def delete_async(self, instance: Model, **kwargs: t.Any) -> t.Any:
124+
async def delete_async(self, instance: ModelType, **kwargs: t.Any) -> t.Any:
119125
return await sync_to_async(self.delete, thread_sensitive=True)(instance)

0 commit comments

Comments
 (0)