Skip to content

Commit cf53210

Browse files
authored
feat: update adapter interface stub (#12)
1 parent ce8bd5e commit cf53210

3 files changed

Lines changed: 40 additions & 21 deletions

File tree

casbin_async_sqlalchemy_adapter/adapter.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import List
1717

1818
from casbin import persist
19+
from casbin.persist.adapters.asyncio import AsyncAdapter
1920
from sqlalchemy import Column, Integer, String, delete
2021
from sqlalchemy import or_
2122
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
@@ -59,7 +60,7 @@ class Filter:
5960
v5 = []
6061

6162

62-
class Adapter(persist.Adapter):
63+
class Adapter(AsyncAdapter):
6364
"""the interface for Casbin adapters."""
6465

6566
def __init__(self, engine, db_class=None, filtered=False, warning=True):
@@ -72,12 +73,21 @@ def __init__(self, engine, db_class=None, filtered=False, warning=True):
7273
db_class = CasbinRule
7374
if warning:
7475
warnings.warn(
75-
'Using default CasbinRule table, please note the use of the "Adapter().create_table()" method to '
76-
'create the table, and ignore this warning if you are using a custom CasbinRule table.',
76+
"Using default CasbinRule table, please note the use of the 'Adapter().create_table()' method"
77+
" to create the table, and ignore this warning if you are using a custom CasbinRule table.",
7778
RuntimeWarning,
7879
)
7980
else:
80-
for attr in ("id", "ptype", "v0", "v1", "v2", "v3", "v4", "v5"): # id attr was used by filter
81+
for attr in (
82+
"id",
83+
"ptype",
84+
"v0",
85+
"v1",
86+
"v2",
87+
"v3",
88+
"v4",
89+
"v5",
90+
): # id attr was used by filter
8191
if not hasattr(db_class, attr):
8292
raise Exception(f"{attr} not found in custom DatabaseClass.")
8393
Base.metadata = db_class.metadata
@@ -124,7 +134,7 @@ async def load_filtered_policy(self, model, filter) -> None:
124134
for line in result.scalars():
125135
persist.load_policy_line(str(line), model)
126136
self._filtered = True
127-
137+
128138
def filter_query(self, stmt, filter):
129139
for attr in ("ptype", "v0", "v1", "v2", "v3", "v4", "v5"):
130140
if len(getattr(filter, attr)) > 0:
@@ -204,7 +214,9 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
204214

205215
return True if r.rowcount > 0 else False
206216

207-
async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rule: List[str]) -> None:
217+
async def update_policy(
218+
self, sec: str, ptype: str, old_rule: List[str], new_rule: List[str]
219+
) -> None:
208220
"""
209221
Update the old_rule with the new_rule in the database (storage).
210222
@@ -236,7 +248,13 @@ async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rul
236248
else:
237249
setattr(old_rule_line, "v{}".format(index), None)
238250

239-
async def update_policies(self, sec: str, ptype: str, old_rules: List[List[str]], new_rules: List[List[str]]) -> None:
251+
async def update_policies(
252+
self,
253+
sec: str,
254+
ptype: str,
255+
old_rules: List[List[str]],
256+
new_rules: List[List[str]],
257+
) -> None:
240258
"""
241259
Update the old_rules with the new_rules in the database (storage).
242260
@@ -250,7 +268,9 @@ async def update_policies(self, sec: str, ptype: str, old_rules: List[List[str]]
250268
for i in range(len(old_rules)):
251269
await self.update_policy(sec, ptype, old_rules[i], new_rules[i])
252270

253-
async def update_filtered_policies(self, sec, ptype, new_rules: List[List[str]], field_index, *field_values) -> List[List[str]]:
271+
async def update_filtered_policies(
272+
self, sec, ptype, new_rules: List[List[str]], field_index, *field_values
273+
) -> List[List[str]]:
254274
"""update_filtered_policies updates all the policies on the basis of the filter."""
255275

256276
filter = Filter()
@@ -271,9 +291,7 @@ async def _update_filtered_policies(self, new_rules, filter) -> List[List[str]]:
271291
async with self._session_scope() as session:
272292
# Load old policies
273293

274-
stmt = select(self._db_class).where(
275-
self._db_class.ptype == filter.ptype
276-
)
294+
stmt = select(self._db_class).where(self._db_class.ptype == filter.ptype)
277295
filtered_stmt = self.filter_query(stmt, filter)
278296
result = await session.execute(filtered_stmt)
279297
old_rules = result.scalars().all()

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
SQLAlchemy>=1.4.0
2-
casbin>=1.23.0
2+
casbin>=1.34.0

tests/test_adapter.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818

1919
import casbin
2020
from sqlalchemy import Column, Integer, String, select
21-
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
22-
from sqlalchemy.orm import sessionmaker
21+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
2322

2423
from casbin_async_sqlalchemy_adapter import Adapter
2524
from casbin_async_sqlalchemy_adapter import Base
@@ -39,7 +38,7 @@ async def get_enforcer():
3938
adapter = Adapter(engine)
4039
await adapter.create_table()
4140

42-
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
41+
async_session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
4342
async with async_session() as s:
4443
s.add(CasbinRule(ptype="p", v0="alice", v1="data1", v2="read"))
4544
s.add(CasbinRule(ptype="p", v0="bob", v1="data2", v2="write"))
@@ -72,7 +71,7 @@ class CustomRule(Base):
7271
async with engine.begin() as conn:
7372
await conn.run_sync(Base.metadata.create_all)
7473

75-
session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
74+
session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
7675
async with session() as s:
7776
s.add(CustomRule(not_exist="NotNone"))
7877
await s.commit()
@@ -122,7 +121,7 @@ async def test_save_policy(self):
122121

123122
async def test_remove_policy(self):
124123
e = await get_enforcer()
125-
124+
126125
self.assertFalse(e.enforce("alice", "data5", "read"))
127126
await e.add_permission_for_user("alice", "data5", "read")
128127
self.assertTrue(e.enforce("alice", "data5", "read"))
@@ -137,7 +136,9 @@ async def test_remove_policies(self):
137136
await e.add_policies((("alice", "data5", "read"), ("alice", "data6", "read")))
138137
self.assertTrue(e.enforce("alice", "data5", "read"))
139138
self.assertTrue(e.enforce("alice", "data6", "read"))
140-
await e.remove_policies((("alice", "data5", "read"), ("alice", "data6", "read")))
139+
await e.remove_policies(
140+
(("alice", "data5", "read"), ("alice", "data6", "read"))
141+
)
141142
self.assertFalse(e.enforce("alice", "data5", "read"))
142143
self.assertFalse(e.enforce("alice", "data6", "read"))
143144

@@ -180,7 +181,7 @@ async def test_repr(self):
180181
self.assertEqual(repr(rule), '<CasbinRule None: "p, alice, data1, read">')
181182
engine = create_async_engine("sqlite+aiosqlite://", future=True)
182183

183-
session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
184+
session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
184185
async with engine.begin() as conn:
185186
await conn.run_sync(Base.metadata.create_all)
186187
s = session()
@@ -370,7 +371,7 @@ async def test_update_policies(self):
370371

371372
self.assertFalse(e.enforce("data2_admin", "data2", "write"))
372373
self.assertTrue(e.enforce("data2_admin", "data_test", "write"))
373-
374+
374375
async def test_update_filtered_policies(self):
375376
e = await get_enforcer()
376377

@@ -392,5 +393,5 @@ async def test_update_filtered_policies(self):
392393
self.assertTrue(e.enforce("bob", "data2", "read"))
393394

394395

395-
if __name__ == '__main__':
396+
if __name__ == "__main__":
396397
unittest.main()

0 commit comments

Comments
 (0)