From eb7ed5dd31a5491d7e05d1a6d2e48e3c89b43463 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov <109198731+Jahongir-Qurbonov@users.noreply.github.com> Date: Mon, 19 Aug 2024 02:01:36 +0500 Subject: [PATCH] Support asynchronous filters (#4434) --- client/src/telethon/_impl/client/client/updates.py | 5 +++-- .../src/telethon/_impl/client/events/filters/combinators.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/client/src/telethon/_impl/client/client/updates.py b/client/src/telethon/_impl/client/client/updates.py index 591280c9..00096378 100644 --- a/client/src/telethon/_impl/client/client/updates.py +++ b/client/src/telethon/_impl/client/client/updates.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio from collections.abc import Awaitable, Callable +from inspect import isawaitable from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, TypeVar from ...session import Gap @@ -23,7 +24,7 @@ def on( self: Client, event_cls: Type[Event], /, filter: Optional[Filter] = None ) -> Callable[[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]]: def wrapper( - handler: Callable[[Event], Awaitable[Any]] + handler: Callable[[Event], Awaitable[Any]], ) -> Callable[[Event], Awaitable[Any]]: add_event_handler(self, handler, event_cls, filter) return handler @@ -145,7 +146,7 @@ async def dispatch_next(client: Client) -> None: for event_cls, handlers in client._handlers.items(): if event := event_cls._try_from_update(client, update, chat_map): for handler, filter in handlers: - if not filter or filter(event): + if not filter or (await r if isawaitable(r := filter(event)) else r): ret = await handler(event) if not (ret is Continue or client._check_all_handlers): return diff --git a/client/src/telethon/_impl/client/events/filters/combinators.py b/client/src/telethon/_impl/client/events/filters/combinators.py index a6bc99d9..10b59187 100644 --- a/client/src/telethon/_impl/client/events/filters/combinators.py +++ b/client/src/telethon/_impl/client/events/filters/combinators.py @@ -1,11 +1,11 @@ import abc import typing from collections.abc import Callable -from typing import TypeAlias +from typing import Awaitable, TypeAlias from ..event import Event -Filter: TypeAlias = Callable[[Event], bool] +Filter: TypeAlias = Callable[[Event], bool | Awaitable[bool]] class Combinable(abc.ABC):