@@ -19,28 +19,46 @@ class Base(DeclarativeBase):
1919 pass
2020
2121
22- class CasbinRule (Base ):
23- __tablename__ = "casbin_rule"
24-
25- id = Column (Integer , primary_key = True )
26- ptype = Column (String (255 ))
27- v0 = Column (String (255 ))
28- v1 = Column (String (255 ))
29- v2 = Column (String (255 ))
30- v3 = Column (String (255 ))
31- v4 = Column (String (255 ))
32- v5 = Column (String (255 ))
33-
34- def __str__ (self ):
35- arr = [self .ptype ]
36- for v in (self .v0 , self .v1 , self .v2 , self .v3 , self .v4 , self .v5 ):
37- if v is None :
38- break
39- arr .append (v )
40- return ", " .join (arr )
22+ def create_casbin_rule_class (table_name ):
23+ """
24+ Factory function to create a CasbinRule class with a custom table name.
25+
26+ Args:
27+ table_name (str): Table name for the CasbinRule class.
28+
29+ Returns:
30+ db_class (CasbinRule): The CasbinRule class.
31+ """
32+
33+ class CasbinRule (Base ):
34+ __tablename__ = table_name
35+ __table_args__ = {"extend_existing" : True }
36+
37+ id = Column (Integer , primary_key = True )
38+ ptype = Column (String (255 ))
39+ v0 = Column (String (255 ))
40+ v1 = Column (String (255 ))
41+ v2 = Column (String (255 ))
42+ v3 = Column (String (255 ))
43+ v4 = Column (String (255 ))
44+ v5 = Column (String (255 ))
45+
46+ def __str__ (self ):
47+ arr = [self .ptype ]
48+ for v in (self .v0 , self .v1 , self .v2 , self .v3 , self .v4 , self .v5 ):
49+ if v is None :
50+ break
51+ arr .append (v )
52+ return ", " .join (arr )
4153
42- def __repr__ (self ):
43- return '<CasbinRule {}: "{}">' .format (self .id , str (self ))
54+ def __repr__ (self ):
55+ return '<CasbinRule {}: "{}">' .format (self .id , str (self ))
56+
57+ return CasbinRule
58+
59+
60+ # Export the default CasbinRule class with table name 'casbin_rule'.
61+ CasbinRule = create_casbin_rule_class ("casbin_rule" )
4462
4563
4664class Filter :
@@ -56,14 +74,20 @@ class Filter:
5674class Adapter (persist .Adapter , persist .adapters .UpdateAdapter ):
5775 """the interface for Casbin adapters."""
5876
59- def __init__ (self , engine , db_class = None , filtered = False ):
77+ def __init__ (
78+ self ,
79+ engine ,
80+ db_class = None ,
81+ table_name = "casbin_rule" ,
82+ filtered = False ,
83+ ):
6084 if isinstance (engine , str ):
6185 self ._engine = create_engine (engine )
6286 else :
6387 self ._engine = engine
6488
6589 if db_class is None :
66- db_class = CasbinRule
90+ db_class = create_casbin_rule_class ( table_name = table_name )
6791 else :
6892 for attr in (
6993 "id" ,
@@ -281,7 +305,6 @@ def _update_filtered_policies(self, new_rules, filter) -> [[str]]:
281305 """_update_filtered_policies updates all the policies on the basis of the filter."""
282306
283307 with self ._session_scope () as session :
284-
285308 # Load old policies
286309
287310 query = session .query (self ._db_class ).filter (
0 commit comments