1616from typing import List
1717
1818from casbin import persist
19+ from casbin .persist .adapters .asyncio import AsyncAdapter
1920from sqlalchemy import Column , Integer , String , delete
2021from sqlalchemy import or_
2122from sqlalchemy .ext .asyncio import create_async_engine , AsyncSession
@@ -59,7 +60,7 @@ class Filter:
5960 v5 = []
6061
6162
62- class Adapter (persist . Adapter ):
63+ class Adapter (AsyncAdapter ):
6364 """the interface for Casbin adapters."""
6465
6566 def __init__ (self , engine , db_class = None , filtered = False , warning = True ):
@@ -72,12 +73,21 @@ def __init__(self, engine, db_class=None, filtered=False, warning=True):
7273 db_class = CasbinRule
7374 if warning :
7475 warnings .warn (
75- ' Using default CasbinRule table, please note the use of the " Adapter().create_table()" method to '
76- ' create the table, and ignore this warning if you are using a custom CasbinRule table.' ,
76+ " Using default CasbinRule table, please note the use of the ' Adapter().create_table()' method"
77+ " to create the table, and ignore this warning if you are using a custom CasbinRule table." ,
7778 RuntimeWarning ,
7879 )
7980 else :
80- for attr in ("id" , "ptype" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" ): # id attr was used by filter
81+ for attr in (
82+ "id" ,
83+ "ptype" ,
84+ "v0" ,
85+ "v1" ,
86+ "v2" ,
87+ "v3" ,
88+ "v4" ,
89+ "v5" ,
90+ ): # id attr was used by filter
8191 if not hasattr (db_class , attr ):
8292 raise Exception (f"{ attr } not found in custom DatabaseClass." )
8393 Base .metadata = db_class .metadata
@@ -124,7 +134,7 @@ async def load_filtered_policy(self, model, filter) -> None:
124134 for line in result .scalars ():
125135 persist .load_policy_line (str (line ), model )
126136 self ._filtered = True
127-
137+
128138 def filter_query (self , stmt , filter ):
129139 for attr in ("ptype" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" ):
130140 if len (getattr (filter , attr )) > 0 :
@@ -204,7 +214,9 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
204214
205215 return True if r .rowcount > 0 else False
206216
207- async def update_policy (self , sec : str , ptype : str , old_rule : List [str ], new_rule : List [str ]) -> None :
217+ async def update_policy (
218+ self , sec : str , ptype : str , old_rule : List [str ], new_rule : List [str ]
219+ ) -> None :
208220 """
209221 Update the old_rule with the new_rule in the database (storage).
210222
@@ -236,7 +248,13 @@ async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rul
236248 else :
237249 setattr (old_rule_line , "v{}" .format (index ), None )
238250
239- async def update_policies (self , sec : str , ptype : str , old_rules : List [List [str ]], new_rules : List [List [str ]]) -> None :
251+ async def update_policies (
252+ self ,
253+ sec : str ,
254+ ptype : str ,
255+ old_rules : List [List [str ]],
256+ new_rules : List [List [str ]],
257+ ) -> None :
240258 """
241259 Update the old_rules with the new_rules in the database (storage).
242260
@@ -250,7 +268,9 @@ async def update_policies(self, sec: str, ptype: str, old_rules: List[List[str]]
250268 for i in range (len (old_rules )):
251269 await self .update_policy (sec , ptype , old_rules [i ], new_rules [i ])
252270
253- async def update_filtered_policies (self , sec , ptype , new_rules : List [List [str ]], field_index , * field_values ) -> List [List [str ]]:
271+ async def update_filtered_policies (
272+ self , sec , ptype , new_rules : List [List [str ]], field_index , * field_values
273+ ) -> List [List [str ]]:
254274 """update_filtered_policies updates all the policies on the basis of the filter."""
255275
256276 filter = Filter ()
@@ -271,9 +291,7 @@ async def _update_filtered_policies(self, new_rules, filter) -> List[List[str]]:
271291 async with self ._session_scope () as session :
272292 # Load old policies
273293
274- stmt = select (self ._db_class ).where (
275- self ._db_class .ptype == filter .ptype
276- )
294+ stmt = select (self ._db_class ).where (self ._db_class .ptype == filter .ptype )
277295 filtered_stmt = self .filter_query (stmt , filter )
278296 result = await session .execute (filtered_stmt )
279297 old_rules = result .scalars ().all ()
0 commit comments