Skip to content

Commit e3075f4

Browse files
Improve stats cog documentation & type annotations (#503)
1 parent deb3a8f commit e3075f4

3 files changed

Lines changed: 14 additions & 15 deletions

File tree

cogs/stats/__init__.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,14 @@
88

99
from config import settings
1010
from db.core.models import LeftDiscordMember
11-
from utils import (
12-
CommandChecks,
13-
TeXBotBaseCog,
14-
)
11+
from utils import CommandChecks, TeXBotBaseCog
1512
from utils.error_capture_decorators import capture_guild_does_not_exist_error
1613

1714
from .counts import get_channel_message_counts, get_server_message_counts
1815
from .graphs import amount_of_time_formatter, plot_bar_chart
1916

2017
if TYPE_CHECKING:
21-
from collections.abc import AsyncIterable, Sequence
18+
from collections.abc import AsyncIterable, Mapping, Sequence
2219
from typing import Final
2320

2421
from utils import TeXBotApplicationContext
@@ -115,7 +112,7 @@ async def channel_stats(
115112

116113
await ctx.defer(ephemeral=True)
117114

118-
message_counts: dict[str, int] = await get_channel_message_counts(channel=channel)
115+
message_counts: Mapping[str, int] = await get_channel_message_counts(channel=channel)
119116

120117
if math.ceil(max(message_counts.values()) / 15) < 1:
121118
await self.command_send_error(
@@ -167,7 +164,7 @@ async def server_stats(self, ctx: "TeXBotApplicationContext") -> None:
167164

168165
await ctx.defer(ephemeral=True)
169166

170-
message_counts: dict[str, dict[str, int]] = await get_server_message_counts(
167+
message_counts: Mapping[str, Mapping[str, int]] = await get_server_message_counts(
171168
guild=main_guild, guest_role=guest_role
172169
)
173170

cogs/stats/counts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
from config import settings
88

99
if TYPE_CHECKING:
10-
from collections.abc import AsyncIterable, Sequence
10+
from collections.abc import AsyncIterable, Mapping, Sequence
1111

1212

1313
__all__: "Sequence[str]" = ("get_channel_message_counts", "get_server_message_counts")
1414

1515

16-
async def get_channel_message_counts(channel: discord.TextChannel) -> dict[str, int]:
16+
async def get_channel_message_counts(channel: discord.TextChannel) -> "Mapping[str, int]":
1717
"""
1818
Get the message counts for each role in the given channel.
1919
@@ -62,14 +62,14 @@ async def get_channel_message_counts(channel: discord.TextChannel) -> dict[str,
6262

6363
async def get_server_message_counts(
6464
guild: discord.Guild, *, guest_role: discord.Role
65-
) -> dict[str, dict[str, int]]:
65+
) -> "Mapping[str, Mapping[str, int]]":
6666
"""
6767
Get the message counts for each channel in the given server.
6868
69-
The message counts are stored in a mapping. It contains a key "roles" which is
69+
The message counts are stored in a mapping. It contains a key "roles", which is
7070
a mapping of role names (prefixed by `@`) to the message counts
7171
for each role across the entire server.
72-
The mapping also contains a key "channels" which is a mapping with the channel
72+
The mapping also contains a key "channels", which is a mapping with the channel
7373
name as a key and the number of messages sent in that channel as the value.
7474
The "roles" sub-mapping also includes a "Total" key for the total number of messages.
7575
"""

cogs/stats/graphs.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import mplcyberpunk
1010

1111
if TYPE_CHECKING:
12-
from collections.abc import Collection, Sequence
12+
from collections.abc import Collection, Mapping, Sequence
1313

1414
from matplotlib.text import Text as Plot_Text
1515

@@ -37,19 +37,21 @@ def amount_of_time_formatter(value: float, time_scale: str) -> str:
3737

3838

3939
def plot_bar_chart(
40-
data: dict[str, int],
40+
data: "Mapping[str, int]",
4141
x_label: str,
4242
y_label: str,
4343
title: str,
4444
filename: str,
4545
description: str,
4646
extra_text: str = "",
4747
) -> discord.File:
48-
"""Generate an image of a plot bar chart from the given data & format variables."""
48+
"""Generate an image of a plot bar chart from the given data and format variables."""
4949
matplotlib.pyplot.style.use("cyberpunk")
5050

5151
max_data_value: int = max(data.values()) + 1
5252

53+
data = dict(data)
54+
5355
# NOTE: The "extra_values" dictionary represents columns of data that should be formatted differently to the standard data columns
5456
extra_values: dict[str, int] = {}
5557
if "Total" in data:

0 commit comments

Comments
 (0)