Support asynchronous filters (#4434)

This commit is contained in:
Jahongir Qurbonov 2024-08-19 02:01:36 +05:00 committed by GitHub
parent 2ab4bed02d
commit eb7ed5dd31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 4 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, TypeVar from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, TypeVar
from ...session import Gap from ...session import Gap
@ -23,7 +24,7 @@ def on(
self: Client, event_cls: Type[Event], /, filter: Optional[Filter] = None self: Client, event_cls: Type[Event], /, filter: Optional[Filter] = None
) -> Callable[[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]]: ) -> Callable[[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]]:
def wrapper( def wrapper(
handler: Callable[[Event], Awaitable[Any]] handler: Callable[[Event], Awaitable[Any]],
) -> Callable[[Event], Awaitable[Any]]: ) -> Callable[[Event], Awaitable[Any]]:
add_event_handler(self, handler, event_cls, filter) add_event_handler(self, handler, event_cls, filter)
return handler return handler
@ -145,7 +146,7 @@ async def dispatch_next(client: Client) -> None:
for event_cls, handlers in client._handlers.items(): for event_cls, handlers in client._handlers.items():
if event := event_cls._try_from_update(client, update, chat_map): if event := event_cls._try_from_update(client, update, chat_map):
for handler, filter in handlers: for handler, filter in handlers:
if not filter or filter(event): if not filter or (await r if isawaitable(r := filter(event)) else r):
ret = await handler(event) ret = await handler(event)
if not (ret is Continue or client._check_all_handlers): if not (ret is Continue or client._check_all_handlers):
return return

View File

@ -1,11 +1,11 @@
import abc import abc
import typing import typing
from collections.abc import Callable from collections.abc import Callable
from typing import TypeAlias from typing import Awaitable, TypeAlias
from ..event import Event from ..event import Event
Filter: TypeAlias = Callable[[Event], bool] Filter: TypeAlias = Callable[[Event], bool | Awaitable[bool]]
class Combinable(abc.ABC): class Combinable(abc.ABC):