Skip to content

Commit 0be76d3

Browse files
committed
fix: migrate stale BannedUsers.db schema and rename to snake_case
Existing databases kept PascalCase table names (Banlist, Bangroups) and the old groupName column because CREATE TABLE IF NOT EXISTS doesn't alter existing tables. Add _migrate_schema() to rename tables and columns on startup, rename all SQL references to snake_case, rename the DB file from BannedUsers.db to banned_users.db, and remove the dead __regex_return_unescaped__ method.
1 parent 9834b20 commit 0be76d3

3 files changed

Lines changed: 70 additions & 20 deletions

File tree

ban_list.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@ def __init__(self, filename):
3737
self.conn = sqlite3.connect(filename)
3838
self.cursor = self.conn.cursor()
3939

40+
self._migrate_schema()
41+
4042
# Create table for bans
4143
self.cursor.execute("""
42-
CREATE TABLE IF NOT EXISTS Banlist(group_name TEXT, pattern TEXT,
44+
CREATE TABLE IF NOT EXISTS ban_list(group_name TEXT, pattern TEXT,
4345
ban_reason TEXT,
4446
timestamp INTEGER, banlength INTEGER
4547
)
@@ -49,11 +51,28 @@ def __init__(self, filename):
4951
# This will be used to check if a group exists
5052
# when checking if a user is banned in that group.
5153
self.cursor.execute("""
52-
CREATE TABLE IF NOT EXISTS Bangroups(group_name TEXT)
54+
CREATE TABLE IF NOT EXISTS ban_groups(group_name TEXT)
5355
""")
5456

5557
self.define_group("Global")
5658

59+
def _migrate_schema(self):
60+
"""Migrate pre-snake_case database schema."""
61+
tables = {row[0] for row in self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()}
62+
63+
for old, new in [("Banlist", "ban_list"), ("Bangroups", "ban_groups")]:
64+
if old in tables and new not in tables:
65+
self.cursor.execute(f"ALTER TABLE {old} RENAME TO {new}")
66+
logger.info("Migrated table %s -> %s", old, new)
67+
68+
for table in ["ban_list", "ban_groups"]:
69+
columns = [row[1] for row in self.cursor.execute(f"PRAGMA table_info({table})").fetchall()]
70+
if "groupName" in columns and "group_name" not in columns:
71+
self.cursor.execute(f"ALTER TABLE {table} RENAME COLUMN groupName TO group_name")
72+
logger.info("Migrated column %s.groupName -> group_name", table)
73+
74+
self.conn.commit()
75+
5776
# You need to define a group name if you want
5877
# to have your own ban groups.
5978
# This should prevent accidents in which an user
@@ -64,7 +83,7 @@ def define_group(self, group_name):
6483
if not does_exist:
6584
self.cursor.execute(
6685
"""
67-
INSERT INTO Bangroups(group_name)
86+
INSERT INTO ban_groups(group_name)
6887
VALUES (?)
6988
""",
7089
(group_name,),
@@ -119,14 +138,14 @@ def unban_user(self, user, ident="*", host="*", group_name="Global"):
119138

120139
def clear_all_bans(self):
121140
self.cursor.execute("""
122-
DELETE FROM Banlist
141+
DELETE FROM ban_list
123142
""")
124143
self.conn.commit()
125144

126145
def clear_group_bans(self, group_name):
127146
self.cursor.execute(
128147
"""
129-
DELETE FROM Banlist
148+
DELETE FROM ban_list
130149
WHERE group_name = ?
131150
""",
132151
(group_name,),
@@ -137,12 +156,12 @@ def get_bans(self, group_name=None, matching_string=None):
137156
if group_name is None:
138157
if matching_string is None:
139158
self.cursor.execute("""
140-
SELECT * FROM Banlist
159+
SELECT * FROM ban_list
141160
""")
142161
else:
143162
self.cursor.execute(
144163
"""
145-
SELECT * FROM Banlist
164+
SELECT * FROM ban_list
146165
WHERE ? GLOB pattern
147166
""",
148167
(matching_string.lower(),),
@@ -155,15 +174,15 @@ def get_bans(self, group_name=None, matching_string=None):
155174
if matching_string is None:
156175
self.cursor.execute(
157176
"""
158-
SELECT * FROM Banlist
177+
SELECT * FROM ban_list
159178
WHERE group_name = ?
160179
""",
161180
(group_name,),
162181
)
163182
else:
164183
self.cursor.execute(
165184
"""
166-
SELECT * FROM Banlist
185+
SELECT * FROM ban_list
167186
WHERE group_name = ? AND ? GLOB pattern
168187
""",
169188
(group_name, matching_string.lower()),
@@ -183,7 +202,7 @@ def check_ban(self, user, ident, host, group_name="Global"):
183202

184203
self.cursor.execute(
185204
"""
186-
SELECT * FROM Banlist
205+
SELECT * FROM ban_list
187206
WHERE group_name = ? AND ? GLOB pattern
188207
""",
189208
(group_name, banstring),
@@ -198,7 +217,7 @@ def check_ban(self, user, ident, host, group_name="Global"):
198217

199218
def get_groups(self):
200219
self.cursor.execute("""
201-
SELECT group_name FROM Bangroups
220+
SELECT group_name FROM ban_groups
202221
""")
203222

204223
group_tuples = self.cursor.fetchall()
@@ -240,9 +259,6 @@ def unescape_banstring(self, banstring):
240259

241260
return finstring.getvalue()
242261

243-
def __regex_return_unescaped__(self, match):
244-
pass
245-
246262
def _ban(
247263
self,
248264
banstring,
@@ -253,7 +269,7 @@ def _ban(
253269
):
254270
self.cursor.execute(
255271
"""
256-
INSERT INTO Banlist(group_name, pattern, ban_reason, timestamp, banlength)
272+
INSERT INTO ban_list(group_name, pattern, ban_reason, timestamp, banlength)
257273
VALUES (?, ?, ?, ?, ?)
258274
""",
259275
(group_name, banstring, ban_reason, timestamp, banlength),
@@ -264,7 +280,7 @@ def _ban(
264280
def _unban(self, banstring, group_name="Global"):
265281
self.cursor.execute(
266282
"""
267-
DELETE FROM Banlist
283+
DELETE FROM ban_list
268284
WHERE group_name = ? AND pattern = ?
269285
""",
270286
(group_name, banstring),
@@ -275,7 +291,7 @@ def _unban(self, banstring, group_name="Global"):
275291
def _ban_exists(self, group_name, banstring):
276292
self.cursor.execute(
277293
"""
278-
SELECT 1 FROM Banlist
294+
SELECT 1 FROM ban_list
279295
WHERE group_name = ? AND pattern = ?
280296
""",
281297
(group_name, banstring),
@@ -288,7 +304,7 @@ def _ban_exists(self, group_name, banstring):
288304
def _group_exists(self, group_name):
289305
self.cursor.execute(
290306
"""
291-
SELECT 1 FROM Bangroups
307+
SELECT 1 FROM ban_groups
292308
WHERE group_name = ?
293309
""",
294310
(group_name,),

command_router.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def __init__(self, channels, cmdprefix, name, ident, adminlist, loglevel):
5555
self._handler_lock = asyncio.Lock()
5656

5757
self.task_pool = task_pool.TaskPool()
58-
self.ban_list = BanList("BannedUsers.db")
58+
if os.path.exists("BannedUsers.db") and not os.path.exists("banned_users.db"):
59+
os.rename("BannedUsers.db", "banned_users.db")
60+
self.ban_list = BanList("banned_users.db")
5961

6062
self.helper = HelpModule()
6163
self.auth = None

tests/test_ban_list.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import sqlite3
2+
13
import pytest
24

3-
from ban_list import InvalidCharacterUsed, NoSuchBanGroup
5+
from ban_list import BanList, InvalidCharacterUsed, NoSuchBanGroup
46

57

68
class TestBanListGroups:
@@ -141,3 +143,33 @@ def test_create_sql_pattern_and_unescape_roundtrip(self, ban_list):
141143

142144
def test_unescape_plain_string(self, ban_list):
143145
assert ban_list.unescape_banstring("plainuser") == "plainuser"
146+
147+
148+
class TestBanListMigration:
149+
def test_migrate_old_schema(self, tmp_path):
150+
"""BanList migrates PascalCase tables and groupName columns."""
151+
db_path = tmp_path / "legacy.db"
152+
conn = sqlite3.connect(db_path)
153+
cur = conn.cursor()
154+
cur.execute(
155+
"CREATE TABLE Banlist(groupName TEXT, pattern TEXT, ban_reason TEXT, timestamp INTEGER, banlength INTEGER)"
156+
)
157+
cur.execute("INSERT INTO Banlist VALUES ('Global', 'bad!*@*', 'test', -1, -1)")
158+
cur.execute("CREATE TABLE Bangroups(groupName TEXT)")
159+
cur.execute("INSERT INTO Bangroups VALUES ('Global')")
160+
conn.commit()
161+
conn.close()
162+
163+
bl = BanList(db_path)
164+
165+
# Old tables should no longer exist
166+
tables = {row[0] for row in bl.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()}
167+
assert "Banlist" not in tables
168+
assert "Bangroups" not in tables
169+
assert "ban_list" in tables
170+
assert "ban_groups" in tables
171+
172+
# Data should be accessible via the new schema
173+
assert bl.get_groups() == ["Global"]
174+
is_banned, _ = bl.check_ban("bad", "*", "*")
175+
assert is_banned is True

0 commit comments

Comments
 (0)