Add reconnection policy support to Sender and related classes

This commit is contained in:
Jahongir Qurbonov 2025-06-02 15:32:36 +05:00
parent 8ef6854516
commit fdf2a05e3e
No known key found for this signature in database
GPG Key ID: 256976CED13D5F2D
5 changed files with 54 additions and 4 deletions

View File

@ -246,6 +246,7 @@ class Client:
update_queue_limit=update_queue_limit, update_queue_limit=update_queue_limit,
base_logger=base_logger, base_logger=base_logger,
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
reconnection_policy=None,
) )
self._session = Session() self._session = Session()
@ -253,9 +254,9 @@ class Client:
self._message_box = MessageBox(base_logger=base_logger) self._message_box = MessageBox(base_logger=base_logger)
self._chat_hashes = ChatHashCache(None) self._chat_hashes = ChatHashCache(None)
self._last_update_limit_warn: Optional[float] = None self._last_update_limit_warn: Optional[float] = None
self._updates: asyncio.Queue[ self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = (
tuple[abcs.Update, dict[int, Peer]] asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0) )
self._dispatcher: Optional[asyncio.Task[None]] = None self._dispatcher: Optional[asyncio.Task[None]] = None
self._handlers: dict[ self._handlers: dict[
Type[Event], Type[Event],

View File

@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar
from ....version import __version__ from ....version import __version__
from ...mtproto import BadStatusError, Full, RpcError from ...mtproto import BadStatusError, Full, RpcError
from ...mtsender import Connector, Sender from ...mtsender import Connector, ReconnectionPolicy, Sender
from ...mtsender import connect as do_connect_sender from ...mtsender import connect as do_connect_sender
from ...session import DataCenter from ...session import DataCenter
from ...session import User as SessionUser from ...session import User as SessionUser
@ -55,6 +55,7 @@ class Config:
datacenter: Optional[DataCenter] = None datacenter: Optional[DataCenter] = None
flood_sleep_threshold: int = 60 flood_sleep_threshold: int = 60
update_queue_limit: Optional[int] = None update_queue_limit: Optional[int] = None
reconnection_policy: Optional[ReconnectionPolicy] = None
KNOWN_DCS = [ KNOWN_DCS = [
@ -100,6 +101,7 @@ async def connect_sender(
auth_key=auth, auth_key=auth,
base_logger=config.base_logger, base_logger=config.base_logger,
connector=config.connector, connector=config.connector,
reconnection_policy=config.reconnection_policy,
) )
try: try:

View File

@ -1,3 +1,4 @@
from .reconnection import ReconnectionPolicy
from .sender import ( from .sender import (
MAXIMUM_DATA, MAXIMUM_DATA,
NO_PING_DISCONNECT, NO_PING_DISCONNECT,
@ -18,4 +19,5 @@ __all__ = [
"Connector", "Connector",
"Sender", "Sender",
"connect", "connect",
"ReconnectionPolicy",
] ]

View File

@ -0,0 +1,38 @@
import time
from abc import ABC, abstractmethod
class ReconnectionPolicy(ABC):
"""
Base class for reconnection policies.
This class defines the interface for reconnection policies used by the MTSender.
It allows for custom reconnection strategies to be implemented by subclasses.
"""
@abstractmethod
def should_retry(self, attempts: int) -> bool:
"""
Determines whether the client should retry the connection attempt.
"""
pass
class NoReconnect(ReconnectionPolicy):
def should_retry(self, attempts: int) -> bool:
return False
class FixedReconnect(ReconnectionPolicy):
__slots__ = ("max_attempts", "delay")
def __init__(self, attempts: int, delay: float):
self.max_attempts = attempts
self.delay = delay
def should_retry(self, attempts: int) -> bool:
if attempts < self.max_attempts:
time.sleep(self.delay)
return True
return False

View File

@ -31,6 +31,7 @@ from ..tl.core import Serializable
from ..tl.mtproto.functions import ping_delay_disconnect from ..tl.mtproto.functions import ping_delay_disconnect
from ..tl.types import UpdateDeleteMessages, UpdateShort from ..tl.types import UpdateDeleteMessages, UpdateShort
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
from .reconnection import ReconnectionPolicy
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024) MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
@ -164,6 +165,7 @@ class Sender:
dc_id: int dc_id: int
addr: str addr: str
_connector: Connector _connector: Connector
_reconnection_policy: Optional[ReconnectionPolicy]
_logger: logging.Logger _logger: logging.Logger
_reader: AsyncReader _reader: AsyncReader
_writer: AsyncWriter _writer: AsyncWriter
@ -188,6 +190,7 @@ class Sender:
addr: str, addr: str,
*, *,
connector: Connector, connector: Connector,
reconnection_policy: Optional[ReconnectionPolicy],
base_logger: logging.Logger, base_logger: logging.Logger,
) -> Self: ) -> Self:
ip, port = addr.split(":") ip, port = addr.split(":")
@ -197,6 +200,7 @@ class Sender:
dc_id=dc_id, dc_id=dc_id,
addr=addr, addr=addr,
_connector=connector, _connector=connector,
_reconnection_policy=reconnection_policy,
_logger=base_logger.getChild("mtsender"), _logger=base_logger.getChild("mtsender"),
_reader=reader, _reader=reader,
_writer=writer, _writer=writer,
@ -536,6 +540,7 @@ async def connect(
auth_key: Optional[bytes], auth_key: Optional[bytes],
base_logger: logging.Logger, base_logger: logging.Logger,
connector: Connector, connector: Connector,
reconnection_policy: Optional[ReconnectionPolicy] = None,
) -> Sender: ) -> Sender:
if auth_key is None: if auth_key is None:
sender = await Sender.connect( sender = await Sender.connect(
@ -544,6 +549,7 @@ async def connect(
dc_id, dc_id,
addr, addr,
connector=connector, connector=connector,
reconnection_policy=reconnection_policy,
base_logger=base_logger, base_logger=base_logger,
) )
return await generate_auth_key(sender) return await generate_auth_key(sender)
@ -554,6 +560,7 @@ async def connect(
dc_id, dc_id,
addr, addr,
connector=connector, connector=connector,
reconnection_policy=reconnection_policy,
base_logger=base_logger, base_logger=base_logger,
) )