diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index 2d434022..8fec6a95 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -246,6 +246,7 @@ class Client: update_queue_limit=update_queue_limit, base_logger=base_logger, connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), + reconnection_policy=None, ) self._session = Session() @@ -253,9 +254,9 @@ class Client: self._message_box = MessageBox(base_logger=base_logger) self._chat_hashes = ChatHashCache(None) self._last_update_limit_warn: Optional[float] = None - self._updates: asyncio.Queue[ - tuple[abcs.Update, dict[int, Peer]] - ] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0) + self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = ( + asyncio.Queue(maxsize=self._config.update_queue_limit or 0) + ) self._dispatcher: Optional[asyncio.Task[None]] = None self._handlers: dict[ Type[Event], diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 808c4b6b..2770ed9a 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar from ....version import __version__ from ...mtproto import BadStatusError, Full, RpcError -from ...mtsender import Connector, Sender +from ...mtsender import Connector, ReconnectionPolicy, Sender from ...mtsender import connect as do_connect_sender from ...session import DataCenter from ...session import User as SessionUser @@ -55,6 +55,7 @@ class Config: datacenter: Optional[DataCenter] = None flood_sleep_threshold: int = 60 update_queue_limit: Optional[int] = None + reconnection_policy: Optional[ReconnectionPolicy] = None KNOWN_DCS = [ @@ -100,6 +101,7 @@ async def connect_sender( auth_key=auth, base_logger=config.base_logger, connector=config.connector, + reconnection_policy=config.reconnection_policy, ) try: diff --git a/client/src/telethon/_impl/mtsender/__init__.py b/client/src/telethon/_impl/mtsender/__init__.py index bc9d723f..aaf1fc6d 100644 --- a/client/src/telethon/_impl/mtsender/__init__.py +++ b/client/src/telethon/_impl/mtsender/__init__.py @@ -1,3 +1,4 @@ +from .reconnection import ReconnectionPolicy from .sender import ( MAXIMUM_DATA, NO_PING_DISCONNECT, @@ -18,4 +19,5 @@ __all__ = [ "Connector", "Sender", "connect", + "ReconnectionPolicy", ] diff --git a/client/src/telethon/_impl/mtsender/reconnection.py b/client/src/telethon/_impl/mtsender/reconnection.py new file mode 100644 index 00000000..2fb51ddd --- /dev/null +++ b/client/src/telethon/_impl/mtsender/reconnection.py @@ -0,0 +1,38 @@ +import time +from abc import ABC, abstractmethod + + +class ReconnectionPolicy(ABC): + """ + Base class for reconnection policies. + + This class defines the interface for reconnection policies used by the MTSender. + It allows for custom reconnection strategies to be implemented by subclasses. + """ + + @abstractmethod + def should_retry(self, attempts: int) -> bool: + """ + Determines whether the client should retry the connection attempt. + """ + pass + + +class NoReconnect(ReconnectionPolicy): + def should_retry(self, attempts: int) -> bool: + return False + + +class FixedReconnect(ReconnectionPolicy): + __slots__ = ("max_attempts", "delay") + + def __init__(self, attempts: int, delay: float): + self.max_attempts = attempts + self.delay = delay + + def should_retry(self, attempts: int) -> bool: + if attempts < self.max_attempts: + time.sleep(self.delay) + return True + + return False diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index fee27620..242f4abd 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -31,6 +31,7 @@ from ..tl.core import Serializable from ..tl.mtproto.functions import ping_delay_disconnect from ..tl.types import UpdateDeleteMessages, UpdateShort from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages +from .reconnection import ReconnectionPolicy MAXIMUM_DATA = (1024 * 1024) + (8 * 1024) @@ -164,6 +165,7 @@ class Sender: dc_id: int addr: str _connector: Connector + _reconnection_policy: Optional[ReconnectionPolicy] _logger: logging.Logger _reader: AsyncReader _writer: AsyncWriter @@ -188,6 +190,7 @@ class Sender: addr: str, *, connector: Connector, + reconnection_policy: Optional[ReconnectionPolicy], base_logger: logging.Logger, ) -> Self: ip, port = addr.split(":") @@ -197,6 +200,7 @@ class Sender: dc_id=dc_id, addr=addr, _connector=connector, + _reconnection_policy=reconnection_policy, _logger=base_logger.getChild("mtsender"), _reader=reader, _writer=writer, @@ -536,6 +540,7 @@ async def connect( auth_key: Optional[bytes], base_logger: logging.Logger, connector: Connector, + reconnection_policy: Optional[ReconnectionPolicy] = None, ) -> Sender: if auth_key is None: sender = await Sender.connect( @@ -544,6 +549,7 @@ async def connect( dc_id, addr, connector=connector, + reconnection_policy=reconnection_policy, base_logger=base_logger, ) return await generate_auth_key(sender) @@ -554,6 +560,7 @@ async def connect( dc_id, addr, connector=connector, + reconnection_policy=reconnection_policy, base_logger=base_logger, )