Skip to content

Commit e2b40ce

Browse files
feat: update pyromod patch with changes from pyromod v3.1.5 (#8)
* feat: bring changes from pyromod v3.1.5 * chore: fix whitespace issues around function parameters --------- Co-authored-by: Alisson Lauffer <alissonvitortc@gmail.com>
1 parent 16cf07c commit e2b40ce

9 files changed

Lines changed: 473 additions & 85 deletions

File tree

pyrogram/client.py

Lines changed: 207 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -336,11 +336,24 @@ async def listen(
336336
listener_type: ListenerTypes = ListenerTypes.MESSAGE,
337337
timeout: Optional[int] = None,
338338
unallowed_click_alert: bool = True,
339-
chat_id: int = None,
340-
user_id: int = None,
341-
message_id: int = None,
342-
inline_message_id: str = None,
339+
chat_id: Union[Union[int, str], List[Union[int, str]]] = None,
340+
user_id: Union[Union[int, str], List[Union[int, str]]] = None,
341+
message_id: Union[int, List[int]] = None,
342+
inline_message_id: Union[str, List[str]] = None,
343343
):
344+
"""
345+
Creates a listener and waits for it to be fulfilled.
346+
347+
:param filters: A filter to check if the listener should be fulfilled.
348+
:param listener_type: The type of listener to create. Defaults to :attr:`pyromod.types.ListenerTypes.MESSAGE`.
349+
:param timeout: The maximum amount of time to wait for the listener to be fulfilled. Defaults to ``None``.
350+
:param unallowed_click_alert: Whether to alert the user if they click on a button that is not intended for them. Defaults to ``True``.
351+
:param chat_id: The chat ID(s) to listen for. Defaults to ``None``.
352+
:param user_id: The user ID(s) to listen for. Defaults to ``None``.
353+
:param message_id: The message ID(s) to listen for. Defaults to ``None``.
354+
:param inline_message_id: The inline message ID(s) to listen for. Defaults to ``None``.
355+
:return: The Message or CallbackQuery that fulfilled the listener.
356+
"""
344357
pattern = Identifier(
345358
from_user_id=user_id,
346359
chat_id=chat_id,
@@ -350,15 +363,6 @@ async def listen(
350363

351364
loop = asyncio.get_event_loop()
352365
future = loop.create_future()
353-
future.add_done_callback(
354-
lambda f: self.stop_listening(
355-
listener_type,
356-
user_id=user_id,
357-
chat_id=chat_id,
358-
message_id=message_id,
359-
inline_message_id=inline_message_id,
360-
)
361-
)
362366

363367
listener = Listener(
364368
future=future,
@@ -368,33 +372,58 @@ async def listen(
368372
listener_type=listener_type,
369373
)
370374

375+
future.add_done_callback(lambda _future: self.remove_listener(listener))
376+
371377
self.listeners[listener_type].append(listener)
372378

373379
try:
374380
return await asyncio.wait_for(future, timeout)
375381
except asyncio.exceptions.TimeoutError:
376382
if callable(PyromodConfig.timeout_handler):
377-
PyromodConfig.timeout_handler(pattern, listener, timeout)
383+
if inspect.iscoroutinefunction(PyromodConfig.timeout_handler.__call__):
384+
await PyromodConfig.timeout_handler(pattern, listener, timeout)
385+
else:
386+
await self.loop.run_in_executor(
387+
None, PyromodConfig.timeout_handler, pattern, listener, timeout
388+
)
378389
elif PyromodConfig.throw_exceptions:
379390
raise ListenerTimeout(timeout)
380391

381392
async def ask(
382393
self,
383-
chat_id: int,
394+
chat_id: Union[Union[int, str], List[Union[int, str]]],
384395
text: str,
385396
filters: Optional[Filter] = None,
386397
listener_type: ListenerTypes = ListenerTypes.MESSAGE,
387398
timeout: Optional[int] = None,
388399
unallowed_click_alert: bool = True,
389-
user_id: int = None,
390-
message_id: int = None,
391-
inline_message_id: str = None,
400+
user_id: Union[Union[int, str], List[Union[int, str]]] = None,
401+
message_id: Union[int, List[int]] = None,
402+
inline_message_id: Union[str, List[str]] = None,
392403
*args,
393404
**kwargs,
394405
):
406+
"""
407+
Sends a message and waits for a response.
408+
409+
:param chat_id: The chat ID(s) to wait for a message from. The first chat ID will be used to send the message.
410+
:param text: The text to send.
411+
:param filters: Same as :meth:`pyromod.types.Client.listen`.
412+
:param listener_type: Same as :meth:`pyromod.types.Client.listen`.
413+
:param timeout: Same as :meth:`pyromod.types.Client.listen`.
414+
:param unallowed_click_alert: Same as :meth:`pyromod.types.Client.listen`.
415+
:param user_id: Same as :meth:`pyromod.types.Client.listen`.
416+
:param message_id: Same as :meth:`pyromod.types.Client.listen`.
417+
:param inline_message_id: Same as :meth:`pyromod.types.Client.listen`.
418+
:param args: Additional arguments to pass to :meth:`pyrogram.Client.send_message`.
419+
:param kwargs: Additional keyword arguments to pass to :meth:`pyrogram.Client.send_message`.
420+
:return:
421+
Same as :meth:`pyromod.types.Client.listen`. The sent message is returned as the attribute ``sent_message``.
422+
"""
395423
sent_message = None
396424
if text.strip() != "":
397-
sent_message = await self.send_message(chat_id, text, *args, **kwargs)
425+
chat_to_ask = chat_id[0] if isinstance(chat_id, list) else chat_id
426+
sent_message = await self.send_message(chat_to_ask, text, *args, **kwargs)
398427

399428
response = await self.listen(
400429
filters=filters,
@@ -411,12 +440,31 @@ async def ask(
411440

412441
return response
413442

414-
def get_matching_listener(
415-
self, pattern: Identifier, listener_type: ListenerTypes
443+
def remove_listener(self, listener: Listener):
444+
"""
445+
Removes a listener from the :meth:`pyromod.types.Client.listeners` dictionary.
446+
447+
:param listener: The listener to remove.
448+
:return: ``void``
449+
"""
450+
try:
451+
self.listeners[listener.listener_type].remove(listener)
452+
except ValueError:
453+
pass
454+
455+
def get_listener_matching_with_data(
456+
self, data: Identifier, listener_type: ListenerTypes
416457
) -> Optional[Listener]:
458+
"""
459+
Gets a listener that matches the given data.
460+
461+
:param data: A :class:`pyromod.types.Identifier` to match against.
462+
:param listener_type: The type of listener to get. Must be a value from :class:`pyromod.types.ListenerTypes`.
463+
:return: The listener that matches the given data or ``None`` if no listener matches.
464+
"""
417465
matching = []
418466
for listener in self.listeners[listener_type]:
419-
if listener.identifier.matches(pattern):
467+
if listener.identifier.matches(data):
420468
matching.append(listener)
421469

422470
# in case of multiple matching listeners, the most specific should be returned
@@ -425,44 +473,163 @@ def count_populated_attributes(listener_item: Listener):
425473

426474
return max(matching, key=count_populated_attributes, default=None)
427475

428-
def remove_listener(self, listener: Listener):
429-
self.listeners[listener.listener_type].remove(listener)
476+
def get_listener_matching_with_identifier_pattern(
477+
self, pattern: Identifier, listener_type: ListenerTypes
478+
) -> Optional[Listener]:
479+
"""
480+
Gets a listener that matches the given identifier pattern.
481+
482+
The difference from :meth:`pyromod.types.Client.get_listener_matching_with_data` is that this method
483+
intends to get a listener by passing partial info of the listener identifier, while the other method
484+
intends to get a listener by passing the full info of the update data, which the listener should match with.
430485
431-
def get_many_matching_listeners(
432-
self, pattern: Identifier, listener_type: ListenerTypes
486+
:param pattern: A :class:`pyromod.types.Identifier` to match against.
487+
:param listener_type: The type of listener to get. Must be a value from :class:`pyromod.types.ListenerTypes`.
488+
:return: The listener that matches the given identifier pattern or ``None`` if no listener matches.
489+
"""
490+
matching = []
491+
for listener in self.listeners[listener_type]:
492+
if pattern.matches(listener.identifier):
493+
matching.append(listener)
494+
495+
# in case of multiple matching listeners, the most specific should be returned
496+
497+
def count_populated_attributes(listener_item: Listener):
498+
return listener_item.identifier.count_populated()
499+
500+
return max(matching, key=count_populated_attributes, default=None)
501+
502+
def get_many_listeners_matching_with_data(
503+
self,
504+
data: Identifier,
505+
listener_type: ListenerTypes,
433506
) -> List[Listener]:
507+
"""
508+
Same of :meth:`pyromod.types.Client.get_listener_matching_with_data` but returns a list of listeners instead of one.
509+
510+
:param data: Same as :meth:`pyromod.types.Client.get_listener_matching_with_data`.
511+
:param listener_type: Same as :meth:`pyromod.types.Client.get_listener_matching_with_data`.
512+
:return: A list of listeners that match the given data.
513+
"""
434514
listeners = []
435515
for listener in self.listeners[listener_type]:
436-
if listener.identifier.matches(pattern):
516+
if listener.identifier.matches(data):
437517
listeners.append(listener)
438518
return listeners
439519

440-
def stop_listening(
520+
def get_many_listeners_matching_with_identifier_pattern(
521+
self,
522+
pattern: Identifier,
523+
listener_type: ListenerTypes,
524+
) -> List[Listener]:
525+
"""
526+
Same of :meth:`pyromod.types.Client.get_listener_matching_with_identifier_pattern` but returns a list of listeners instead of one.
527+
528+
:param pattern: Same as :meth:`pyromod.types.Client.get_listener_matching_with_identifier_pattern`.
529+
:param listener_type: Same as :meth:`pyromod.types.Client.get_listener_matching_with_identifier_pattern`.
530+
:return: A list of listeners that match the given identifier pattern.
531+
"""
532+
listeners = []
533+
for listener in self.listeners[listener_type]:
534+
if pattern.matches(listener.identifier):
535+
listeners.append(listener)
536+
return listeners
537+
538+
async def stop_listening(
441539
self,
442540
listener_type: ListenerTypes = ListenerTypes.MESSAGE,
443-
chat_id: int = None,
444-
user_id: int = None,
445-
message_id: int = None,
446-
inline_message_id: str = None,
541+
chat_id: Union[Union[int, str], List[Union[int, str]]] = None,
542+
user_id: Union[Union[int, str], List[Union[int, str]]] = None,
543+
message_id: Union[int, List[int]] = None,
544+
inline_message_id: Union[str, List[str]] = None,
447545
):
546+
"""
547+
Stops all listeners that match the given identifier pattern.
548+
Uses :meth:`pyromod.types.Client.get_many_listeners_matching_with_identifier_pattern`.
549+
550+
:param listener_type: The type of listener to stop. Must be a value from :class:`pyromod.types.ListenerTypes`.
551+
:param chat_id: The chat_id to match against.
552+
:param user_id: The user_id to match against.
553+
:param message_id: The message_id to match against.
554+
:param inline_message_id: The inline_message_id to match against.
555+
:return: ``void``
556+
"""
448557
pattern = Identifier(
449558
from_user_id=user_id,
450559
chat_id=chat_id,
451560
message_id=message_id,
452561
inline_message_id=inline_message_id,
453562
)
454-
listeners = self.get_many_matching_listeners(pattern, listener_type)
563+
listeners = self.get_many_listeners_matching_with_identifier_pattern(
564+
pattern, listener_type
565+
)
455566

456567
for listener in listeners:
457-
self.remove_listener(listener)
568+
await self.stop_listener(listener)
458569

459-
if listener.future.done():
460-
return
570+
async def stop_listener(self, listener: Listener):
571+
"""
572+
Stops a listener, calling stopped_handler if applicable or raising ListenerStopped if throw_exceptions is True.
461573
462-
if callable(PyromodConfig.stopped_handler):
463-
PyromodConfig.stopped_handler(pattern, listener)
464-
elif PyromodConfig.throw_exceptions:
465-
listener.future.set_exception(ListenerStopped())
574+
:param listener: The :class:`pyromod.types.Listener` to stop.
575+
:return: ``void``
576+
:raises ListenerStopped: If throw_exceptions is True.
577+
"""
578+
self.remove_listener(listener)
579+
580+
if listener.future.done():
581+
return
582+
583+
if callable(PyromodConfig.stopped_handler):
584+
if inspect.iscoroutinefunction(PyromodConfig.stopped_handler.__call__):
585+
await PyromodConfig.stopped_handler(None, listener)
586+
else:
587+
await self.loop.run_in_executor(
588+
None, PyromodConfig.stopped_handler, None, listener
589+
)
590+
elif PyromodConfig.throw_exceptions:
591+
listener.future.set_exception(ListenerStopped())
592+
593+
def register_next_step_handler(
594+
self,
595+
callback: Callable,
596+
filters: Optional[Filter] = None,
597+
listener_type: ListenerTypes = ListenerTypes.MESSAGE,
598+
unallowed_click_alert: bool = True,
599+
chat_id: Union[Union[int, str], List[Union[int, str]]] = None,
600+
user_id: Union[Union[int, str], List[Union[int, str]]] = None,
601+
message_id: Union[int, List[int]] = None,
602+
inline_message_id: Union[str, List[str]] = None,
603+
):
604+
"""
605+
Registers a listener with a callback to be called when the listener is fulfilled.
606+
607+
:param callback: The callback to call when the listener is fulfilled.
608+
:param filters: Same as :meth:`pyromod.types.Client.listen`.
609+
:param listener_type: Same as :meth:`pyromod.types.Client.listen`.
610+
:param unallowed_click_alert: Same as :meth:`pyromod.types.Client.listen`.
611+
:param chat_id: Same as :meth:`pyromod.types.Client.listen`.
612+
:param user_id: Same as :meth:`pyromod.types.Client.listen`.
613+
:param message_id: Same as :meth:`pyromod.types.Client.listen`.
614+
:param inline_message_id: Same as :meth:`pyromod.types.Client.listen`.
615+
:return: ``void``
616+
"""
617+
pattern = Identifier(
618+
from_user_id=user_id,
619+
chat_id=chat_id,
620+
message_id=message_id,
621+
inline_message_id=inline_message_id,
622+
)
623+
624+
listener = Listener(
625+
callback=callback,
626+
filters=filters,
627+
unallowed_click_alert=unallowed_click_alert,
628+
identifier=pattern,
629+
listener_type=listener_type,
630+
)
631+
632+
self.listeners[listener_type].append(listener)
466633

467634
async def authorize(self) -> User:
468635
if self.bot_token:

0 commit comments

Comments
 (0)