22import typing as t
33
44from asgiref .sync import sync_to_async
5- from django .db .models import Model , QuerySet
5+ from django .db .models import QuerySet
66from pydantic import BaseModel as PydanticModel
77
88from ninja_extra .exceptions import APIException , NotFound
99from ninja_extra .shortcuts import get_object_or_exception
1010
11- from .interfaces import AsyncModelServiceBase , ModelServiceBase
11+ from .interfaces import AsyncModelServiceBase , ModelServiceBase , ModelType
1212
1313
14- class ModelService (ModelServiceBase , AsyncModelServiceBase ):
14+ class ModelService (
15+ ModelServiceBase [ModelType ], AsyncModelServiceBase [ModelType ], t .Generic [ModelType ]
16+ ):
1517 """
1618 Model Service for Model Controller model CRUD operations with a simple logic for simple models.
1719
1820 Its advised you override this class if you have a complex model.
1921 """
2022
21- def __init__ (self , model : t .Type [Model ]) -> None :
23+ def __init__ (self , model : t .Type [ModelType ]) -> None :
2224 self .model = model
2325
2426 def get_one (
@@ -28,14 +30,14 @@ def get_one(
2830 error_message : t .Optional [str ] = None ,
2931 exception : t .Type [APIException ] = NotFound ,
3032 ** kwargs : t .Any ,
31- ) -> t . Any :
33+ ) -> ModelType :
3234 obj = get_object_or_exception (
3335 klass = self .model if queryset is None else queryset ,
3436 error_message = error_message ,
3537 exception = exception ,
3638 pk = pk ,
3739 )
38- return obj
40+ return t . cast ( ModelType , obj )
3941
4042 async def get_one_async (
4143 self ,
@@ -44,7 +46,7 @@ async def get_one_async(
4446 error_message : t .Optional [str ] = None ,
4547 exception : t .Type [APIException ] = NotFound ,
4648 ** kwargs : t .Any ,
47- ) -> t . Any :
49+ ) -> ModelType :
4850 return await sync_to_async (self .get_one , thread_sensitive = True )(
4951 pk ,
5052 queryset = queryset ,
@@ -53,13 +55,13 @@ async def get_one_async(
5355 ** kwargs ,
5456 )
5557
56- def get_all (self , ** kwargs : t .Any ) -> t . Union [ QuerySet , t . List [ t . Any ] ]:
58+ def get_all (self , ** kwargs : t .Any ) -> QuerySet [ ModelType ]:
5759 return self .model .objects .all ()
5860
59- async def get_all_async (self , ** kwargs : t .Any ) -> t . Union [ QuerySet , t . List [ t . Any ] ]:
61+ async def get_all_async (self , ** kwargs : t .Any ) -> QuerySet [ ModelType ]:
6062 return await sync_to_async (self .get_all , thread_sensitive = True )(** kwargs )
6163
62- def create (self , schema : PydanticModel , ** kwargs : t .Any ) -> t . Any :
64+ def create (self , schema : PydanticModel , ** kwargs : t .Any ) -> ModelType :
6365 data = schema .model_dump (by_alias = True )
6466 data .update (kwargs )
6567
@@ -86,10 +88,12 @@ def create(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
8688 )
8789 raise TypeError (msg ) from tex
8890
89- async def create_async (self , schema : PydanticModel , ** kwargs : t .Any ) -> t . Any :
91+ async def create_async (self , schema : PydanticModel , ** kwargs : t .Any ) -> ModelType :
9092 return await sync_to_async (self .create , thread_sensitive = True )(schema , ** kwargs )
9193
92- def update (self , instance : Model , schema : PydanticModel , ** kwargs : t .Any ) -> t .Any :
94+ def update (
95+ self , instance : ModelType , schema : PydanticModel , ** kwargs : t .Any
96+ ) -> ModelType :
9397 data = schema .model_dump (exclude_unset = True )
9498 data .update (kwargs )
9599 for attr , value in data .items ():
@@ -98,22 +102,24 @@ def update(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.A
98102 return instance
99103
100104 async def update_async (
101- self , instance : Model , schema : PydanticModel , ** kwargs : t .Any
102- ) -> t . Any :
105+ self , instance : ModelType , schema : PydanticModel , ** kwargs : t .Any
106+ ) -> ModelType :
103107 return await sync_to_async (self .update , thread_sensitive = True )(
104108 instance , schema , ** kwargs
105109 )
106110
107- def patch (self , instance : Model , schema : PydanticModel , ** kwargs : t .Any ) -> t .Any :
111+ def patch (
112+ self , instance : ModelType , schema : PydanticModel , ** kwargs : t .Any
113+ ) -> ModelType :
108114 return self .update (instance = instance , schema = schema , ** kwargs )
109115
110116 async def patch_async (
111- self , instance : Model , schema : PydanticModel , ** kwargs : t .Any
112- ) -> t . Any :
117+ self , instance : ModelType , schema : PydanticModel , ** kwargs : t .Any
118+ ) -> ModelType :
113119 return await self .update_async (instance = instance , schema = schema , ** kwargs )
114120
115- def delete (self , instance : Model , ** kwargs : t .Any ) -> t .Any :
121+ def delete (self , instance : ModelType , ** kwargs : t .Any ) -> t .Any :
116122 instance .delete ()
117123
118- async def delete_async (self , instance : Model , ** kwargs : t .Any ) -> t .Any :
124+ async def delete_async (self , instance : ModelType , ** kwargs : t .Any ) -> t .Any :
119125 return await sync_to_async (self .delete , thread_sensitive = True )(instance )
0 commit comments