Skip to content

Commit ab77631

Browse files
authored
Merge pull request #1 from sampingantech/fix/delete-policy
Fix/delete policy
2 parents 9bad463 + 68d8072 commit ab77631

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

casbin_databases_adapter/adapter.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ class Filter:
1818

1919

2020
class DatabasesAdapter(persist.Adapter):
21+
22+
cols = ["ptype"] + [f"v{i}" for i in range(6)]
23+
2124
def __init__(self, db: Database, table: Table, filtered=False):
2225
self.db: Database = db
2326
self.table: Table = table
@@ -29,7 +32,7 @@ async def load_policy(self, model: Model):
2932
rows = await self.db.fetch_all(query)
3033
for row in rows:
3134
# convert row from tuple to csv format and removing the first column (id)
32-
line = [i for i in row[1:] if i]
35+
line = [v for k, v in row.items() if k in self.cols and v is not None]
3336
persist.load_policy_line(", ".join(line), model)
3437

3538
@to_sync()
@@ -68,7 +71,7 @@ async def remove_policy(self, sec, p_type, rule):
6871

6972
@to_sync()
7073
async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
71-
query = self.table.select().where(self.table.columns.ptype == ptype)
74+
query = self.table.delete().where(self.table.columns.ptype == ptype)
7275
if not (0 <= field_index <= 5):
7376
return False
7477
if not (1 <= field_index + len(field_values) <= 6):
@@ -77,7 +80,7 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
7780
if len(value) > 0:
7881
query = query.where(self.table.columns[f"v{field_index+1}"] == value)
7982
result = await self.db.execute(query)
80-
return True if result > 0 else False
83+
return True if result else False
8184

8285
@to_sync()
8386
async def load_filtered_policy(self, model: Model, filter_: Filter) -> None:
@@ -88,7 +91,7 @@ async def load_filtered_policy(self, model: Model, filter_: Filter) -> None:
8891
rows = await self.db.fetch_all(query)
8992
for row in rows:
9093
# convert row from tuple to csv format and removing the first column (id)
91-
line = [i for i in row[1:] if i]
94+
line = [v for k, v in row.items() if k in self.cols and v is not None]
9295
persist.load_policy_line(", ".join(line), model)
9396

9497
def is_filtered(self):

0 commit comments

Comments
 (0)