diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 854860fd..10eb8a77 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -1,4 +1,5 @@ import asyncio +import inspect import itertools import random import time @@ -424,7 +425,10 @@ class UpdateMethods: if not builder.resolved: await builder.resolve(self) - if not builder.filter(event): + filter = builder.filter(event) + if inspect.isawaitable(filter): + filter = await filter + if not filter: continue try: diff --git a/telethon/events/callbackquery.py b/telethon/events/callbackquery.py index c35e348b..d1558d21 100644 --- a/telethon/events/callbackquery.py +++ b/telethon/events/callbackquery.py @@ -118,8 +118,10 @@ class CallbackQuery(EventBuilder): elif event.query.data != self.match: return - if not self.func or self.func(event): - return event + if self.func: + # Return the result of func directly as it may need to be awaited + return self.func(event) + return True class Event(EventCommon, SenderGetter): """ diff --git a/telethon/events/common.py b/telethon/events/common.py index 42586608..f5ff5b50 100644 --- a/telethon/events/common.py +++ b/telethon/events/common.py @@ -55,7 +55,7 @@ class EventBuilder(abc.ABC): which will be ignored if ``blacklist_chats=True``. func (`callable`, optional): - A callable function that should accept the event as input + A callable (async or not) function that should accept the event as input parameter, and return a value indicating whether the event should be dispatched or not (any truthy value will do, it does not need to be a `bool`). It works like a custom filter: @@ -105,13 +105,13 @@ class EventBuilder(abc.ABC): def filter(self, event): """ - If the ID of ``event._chat_peer`` isn't in the chats set (or it is - but the set is a blacklist) returns `None`, otherwise the event. + Returns a truthy value if the event passed the filter and should be + used, or falsy otherwise. The return value may need to be awaited. The events must have been resolved before this can be called. """ if not self.resolved: - return None + return if self.chats is not None: # Note: the `event.chat_id` property checks if it's `None` for us @@ -119,10 +119,13 @@ class EventBuilder(abc.ABC): if inside == self.blacklist_chats: # If this chat matches but it's a blacklist ignore. # If it doesn't match but it's a whitelist ignore. - return None + return - if not self.func or self.func(event): - return event + if not self.func: + return True + + # Return the result of func directly as it may need to be awaited + return self.func(event) class EventCommon(ChatGetter, abc.ABC): diff --git a/telethon/events/raw.py b/telethon/events/raw.py index 912a934d..84910778 100644 --- a/telethon/events/raw.py +++ b/telethon/events/raw.py @@ -46,6 +46,8 @@ class Raw(EventBuilder): return update def filter(self, event): - if ((not self.types or isinstance(event, self.types)) - and (not self.func or self.func(event))): + if not self.types or isinstance(event, self.types): + if self.func: + # Return the result of func directly as it may need to be awaited + return self.func(event) return event diff --git a/telethon/tl/custom/conversation.py b/telethon/tl/custom/conversation.py index c0d59e03..46b67aa8 100644 --- a/telethon/tl/custom/conversation.py +++ b/telethon/tl/custom/conversation.py @@ -1,5 +1,6 @@ import asyncio import functools +import inspect import itertools import time @@ -312,9 +313,15 @@ class Conversation(ChatGetter): for key, (ev, fut) in list(self._custom.items()): ev_type = type(ev) inst = built[ev_type] - if inst and ev.filter(inst): - fut.set_result(inst) - del self._custom[key] + + if inst: + filter = ev.filter(inst) + if inspect.isawaitable(filter): + filter = await filter + + if filter: + fut.set_result(inst) + del self._custom[key] def _on_new_message(self, response): response = response.message