Skip to content

Commit 87a7e73

Browse files
committed
Permissions
1 parent b6a6b5a commit 87a7e73

3 files changed

Lines changed: 259 additions & 0 deletions

File tree

abstra_json_sql/ast.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,17 @@ class AndExpression(Expression):
4646
left: Expression
4747
right: Expression
4848

49+
@classmethod
50+
def from_list(cls, expressions: List[Expression]) -> "AndExpression":
51+
if len(expressions) == 0:
52+
return AndExpression(TrueExpression(), TrueExpression())
53+
elif len(expressions) == 1:
54+
return AndExpression(TrueExpression(), expressions[0])
55+
else:
56+
return AndExpression(
57+
expressions[0], AndExpression.from_list(expressions[1:])
58+
)
59+
4960

5061
@dataclass
5162
class OrExpression(Expression):
@@ -213,6 +224,15 @@ class Select(Command):
213224
order_part: Optional[OrderBy] = None
214225
limit_part: Optional[Limit] = None
215226

227+
def get_tables(self) -> List[str]:
228+
tables = []
229+
if self.from_part:
230+
tables.append(self.from_part.table)
231+
if self.from_part.join:
232+
for join in self.from_part.join:
233+
tables.append(join.table)
234+
return tables
235+
216236

217237
@dataclass
218238
class WithPart(Ast):

abstra_json_sql/authorization.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from dataclasses import dataclass
2+
from typing import List, Literal, Optional
3+
4+
from .ast import (
5+
AndExpression,
6+
Command,
7+
Delete,
8+
EqualExpression,
9+
Expression,
10+
Insert,
11+
NameExpression,
12+
Select,
13+
Update,
14+
)
15+
from .lexer import scan
16+
from .parser import parse, parse_expression
17+
from .tables import ITablesSnapshot
18+
19+
RuleCommand = Literal["SELECT", "INSERT", "UPDATE", "DELETE"]
20+
21+
22+
def includes_expression(main: Expression, sub: Expression) -> bool:
23+
if main == sub:
24+
return True
25+
elif isinstance(main, AndExpression):
26+
return includes_expression(main.left, sub) or includes_expression(
27+
main.right, sub
28+
)
29+
return False
30+
31+
32+
def validate_insert_condition(condition: Expression) -> bool:
33+
if isinstance(condition, AndExpression):
34+
return validate_insert_condition(condition.left) and validate_insert_condition(
35+
condition.right
36+
)
37+
elif isinstance(condition, EqualExpression):
38+
if isinstance(condition.left, NameExpression):
39+
return True
40+
return False
41+
42+
43+
@dataclass
44+
class Rule:
45+
type: Literal["GRANT", "REVOKE"]
46+
command: RuleCommand
47+
table_name: str
48+
condition: Optional[str] = None
49+
50+
def __post_init__(self):
51+
if self.command == "INSERT" and self.condition is not None:
52+
cond_exp, _ = parse_expression(scan(self.condition))
53+
if not validate_insert_condition(cond_exp):
54+
raise NotImplementedError(
55+
"Only simple equality conditions are supported for INSERT rules."
56+
)
57+
58+
def condition_met(self, command: Command) -> bool:
59+
if self.condition is None:
60+
return True
61+
if isinstance(command, Select) and self.command == "SELECT":
62+
cond_exp, _ = parse_expression(scan(self.condition))
63+
if command.where_part is None:
64+
return False
65+
return includes_expression(command.where_part.expression, cond_exp)
66+
elif isinstance(command, Update) and self.command == "UPDATE":
67+
cond_exp, _ = parse_expression(scan(self.condition))
68+
if command.where is None:
69+
return False
70+
return includes_expression(command.where.expression, cond_exp)
71+
elif isinstance(command, Delete) and self.command == "DELETE":
72+
cond_exp, _ = parse_expression(scan(self.condition))
73+
if command.where is None:
74+
return False
75+
return includes_expression(command.where.expression, cond_exp)
76+
elif isinstance(command, Insert) and self.command == "INSERT":
77+
cond_exp, _ = parse_expression(scan(self.condition))
78+
if command.columns is None:
79+
return False
80+
81+
for value in command.values:
82+
insert_expression = AndExpression.from_list(
83+
[
84+
EqualExpression(NameExpression(field), val)
85+
for field, val in zip(command.columns, value)
86+
]
87+
)
88+
if includes_expression(insert_expression, cond_exp):
89+
return True
90+
return False
91+
92+
def check(self, command: Command) -> Literal["ALLOW", "DENY", "NO_MATCH"]:
93+
if isinstance(command, Select) and self.command == "SELECT":
94+
if self.table_name in command.get_tables():
95+
if self.condition_met(command):
96+
return "ALLOW" if self.type == "GRANT" else "DENY"
97+
else:
98+
return "NO_MATCH"
99+
elif isinstance(command, Insert) and self.command == "INSERT":
100+
if self.table_name == command.table_name:
101+
if self.condition_met(command):
102+
return "ALLOW" if self.type == "GRANT" else "DENY"
103+
else:
104+
return "NO_MATCH"
105+
elif isinstance(command, Update) and self.command == "UPDATE":
106+
if self.table_name == command.table_name:
107+
if self.condition_met(command):
108+
return "ALLOW" if self.type == "GRANT" else "DENY"
109+
else:
110+
return "NO_MATCH"
111+
elif isinstance(command, Delete) and self.command == "DELETE":
112+
if self.table_name == command.table_name:
113+
if self.condition_met(command):
114+
return "ALLOW" if self.type == "GRANT" else "DENY"
115+
else:
116+
return "NO_MATCH"
117+
return "NO_MATCH"
118+
119+
120+
class Permissions:
121+
default: bool
122+
rules: List[Rule]
123+
124+
def __init__(self, default: bool = False):
125+
self.default = default
126+
self.rules = []
127+
128+
def grant(
129+
self, command: RuleCommand, table_name: str, condition: Optional[str]
130+
) -> "ITablesSnapshot":
131+
if condition is None:
132+
133+
def condition(_):
134+
return self.default
135+
136+
self.rules.append(Rule("GRANT", command, table_name, condition))
137+
return self
138+
139+
def revoke(
140+
self, command: RuleCommand, table_name: str, condition: Optional[str]
141+
) -> "ITablesSnapshot":
142+
if condition is None:
143+
144+
def condition(_):
145+
return not self.default
146+
147+
self.rules.append(Rule("REVOKE", command, table_name, condition))
148+
return self
149+
150+
def allowed(self, sql: str):
151+
cmd = parse(scan(sql))
152+
allowed = self.default
153+
for rule in self.rules:
154+
result = rule.check(cmd)
155+
if result == "ALLOW":
156+
allowed = True
157+
elif result == "DENY":
158+
allowed = False
159+
return allowed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from unittest import TestCase
2+
3+
from .authorization import Permissions
4+
5+
6+
class TestPermissions(TestCase):
7+
def test_select(self):
8+
p = Permissions(default=False)
9+
p.grant("SELECT", "users", "age > 18")
10+
11+
self.assertFalse(p.allowed("select * from users"))
12+
self.assertTrue(p.allowed("select * from users where age > 18"))
13+
14+
p.revoke("SELECT", "users", "age < 21")
15+
self.assertFalse(p.allowed("select * from users where age < 21"))
16+
17+
def test_select_complex(self):
18+
p = Permissions(default=False)
19+
p.grant("SELECT", "orders", "status = 'completed' AND total > 100")
20+
21+
self.assertFalse(p.allowed("select * from orders where status = 'pending'"))
22+
self.assertFalse(p.allowed("select * from orders where total <= 100"))
23+
self.assertTrue(
24+
p.allowed("select * from orders where status = 'completed' AND total > 100")
25+
)
26+
27+
p.revoke("SELECT", "orders", "customer_id = 42")
28+
self.assertFalse(
29+
p.allowed(
30+
"select * from orders where customer_id = 42 AND status = 'completed' AND total > 100"
31+
)
32+
)
33+
34+
def test_insert(self):
35+
p = Permissions(default=False)
36+
p.grant("INSERT", "users", "name = 'Alice'")
37+
38+
self.assertFalse(p.allowed("insert into users (name, age) values ('Bob', 25)"))
39+
self.assertTrue(p.allowed("insert into users (name, age) values ('Alice', 30)"))
40+
41+
p.revoke("INSERT", "users", "age = 18")
42+
self.assertFalse(
43+
p.allowed("insert into users (name, age) values ('Alice', 18)")
44+
)
45+
self.assertTrue(p.allowed("insert into users (name, age) values ('Alice', 20)"))
46+
47+
def test_update(self):
48+
p = Permissions(default=False)
49+
p.grant("UPDATE", "users", "age < 65")
50+
51+
self.assertFalse(p.allowed("update users set age = age + 1 where age >= 65"))
52+
self.assertTrue(p.allowed("update users set age = age + 1 where age < 65"))
53+
54+
p.revoke("UPDATE", "users", "status = 'inactive'")
55+
self.assertFalse(
56+
p.allowed("update users set age = age + 1 where status = 'inactive'")
57+
)
58+
self.assertFalse(
59+
p.allowed("update users set age = age + 1 where status = 'active'")
60+
)
61+
62+
p.grant("UPDATE", "users", "status = 'active'")
63+
self.assertTrue(
64+
p.allowed("update users set age = age + 1 where status = 'active'")
65+
)
66+
67+
def test_delete(self):
68+
p = Permissions(default=False)
69+
p.grant("DELETE", "users", "status = 'inactive'")
70+
71+
self.assertFalse(p.allowed("delete from users where status = 'active'"))
72+
self.assertTrue(p.allowed("delete from users where status = 'inactive'"))
73+
74+
p.revoke("DELETE", "users", "age < 18")
75+
self.assertFalse(
76+
p.allowed("delete from users where status = 'inactive' and age < 18")
77+
)
78+
self.assertTrue(
79+
p.allowed("delete from users where status = 'inactive' and age >= 18")
80+
)

0 commit comments

Comments
 (0)