mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-06-23 13:36:39 +00:00
Add reconnection policy support to Sender and related classes
This commit is contained in:
parent
8ef6854516
commit
fdf2a05e3e
@ -246,6 +246,7 @@ class Client:
|
|||||||
update_queue_limit=update_queue_limit,
|
update_queue_limit=update_queue_limit,
|
||||||
base_logger=base_logger,
|
base_logger=base_logger,
|
||||||
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
|
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
|
||||||
|
reconnection_policy=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._session = Session()
|
self._session = Session()
|
||||||
@ -253,9 +254,9 @@ class Client:
|
|||||||
self._message_box = MessageBox(base_logger=base_logger)
|
self._message_box = MessageBox(base_logger=base_logger)
|
||||||
self._chat_hashes = ChatHashCache(None)
|
self._chat_hashes = ChatHashCache(None)
|
||||||
self._last_update_limit_warn: Optional[float] = None
|
self._last_update_limit_warn: Optional[float] = None
|
||||||
self._updates: asyncio.Queue[
|
self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = (
|
||||||
tuple[abcs.Update, dict[int, Peer]]
|
asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
|
||||||
] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
|
)
|
||||||
self._dispatcher: Optional[asyncio.Task[None]] = None
|
self._dispatcher: Optional[asyncio.Task[None]] = None
|
||||||
self._handlers: dict[
|
self._handlers: dict[
|
||||||
Type[Event],
|
Type[Event],
|
||||||
|
@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar
|
|||||||
|
|
||||||
from ....version import __version__
|
from ....version import __version__
|
||||||
from ...mtproto import BadStatusError, Full, RpcError
|
from ...mtproto import BadStatusError, Full, RpcError
|
||||||
from ...mtsender import Connector, Sender
|
from ...mtsender import Connector, ReconnectionPolicy, Sender
|
||||||
from ...mtsender import connect as do_connect_sender
|
from ...mtsender import connect as do_connect_sender
|
||||||
from ...session import DataCenter
|
from ...session import DataCenter
|
||||||
from ...session import User as SessionUser
|
from ...session import User as SessionUser
|
||||||
@ -55,6 +55,7 @@ class Config:
|
|||||||
datacenter: Optional[DataCenter] = None
|
datacenter: Optional[DataCenter] = None
|
||||||
flood_sleep_threshold: int = 60
|
flood_sleep_threshold: int = 60
|
||||||
update_queue_limit: Optional[int] = None
|
update_queue_limit: Optional[int] = None
|
||||||
|
reconnection_policy: Optional[ReconnectionPolicy] = None
|
||||||
|
|
||||||
|
|
||||||
KNOWN_DCS = [
|
KNOWN_DCS = [
|
||||||
@ -100,6 +101,7 @@ async def connect_sender(
|
|||||||
auth_key=auth,
|
auth_key=auth,
|
||||||
base_logger=config.base_logger,
|
base_logger=config.base_logger,
|
||||||
connector=config.connector,
|
connector=config.connector,
|
||||||
|
reconnection_policy=config.reconnection_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from .reconnection import ReconnectionPolicy
|
||||||
from .sender import (
|
from .sender import (
|
||||||
MAXIMUM_DATA,
|
MAXIMUM_DATA,
|
||||||
NO_PING_DISCONNECT,
|
NO_PING_DISCONNECT,
|
||||||
@ -18,4 +19,5 @@ __all__ = [
|
|||||||
"Connector",
|
"Connector",
|
||||||
"Sender",
|
"Sender",
|
||||||
"connect",
|
"connect",
|
||||||
|
"ReconnectionPolicy",
|
||||||
]
|
]
|
||||||
|
38
client/src/telethon/_impl/mtsender/reconnection.py
Normal file
38
client/src/telethon/_impl/mtsender/reconnection.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class ReconnectionPolicy(ABC):
|
||||||
|
"""
|
||||||
|
Base class for reconnection policies.
|
||||||
|
|
||||||
|
This class defines the interface for reconnection policies used by the MTSender.
|
||||||
|
It allows for custom reconnection strategies to be implemented by subclasses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def should_retry(self, attempts: int) -> bool:
|
||||||
|
"""
|
||||||
|
Determines whether the client should retry the connection attempt.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NoReconnect(ReconnectionPolicy):
|
||||||
|
def should_retry(self, attempts: int) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class FixedReconnect(ReconnectionPolicy):
|
||||||
|
__slots__ = ("max_attempts", "delay")
|
||||||
|
|
||||||
|
def __init__(self, attempts: int, delay: float):
|
||||||
|
self.max_attempts = attempts
|
||||||
|
self.delay = delay
|
||||||
|
|
||||||
|
def should_retry(self, attempts: int) -> bool:
|
||||||
|
if attempts < self.max_attempts:
|
||||||
|
time.sleep(self.delay)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
@ -31,6 +31,7 @@ from ..tl.core import Serializable
|
|||||||
from ..tl.mtproto.functions import ping_delay_disconnect
|
from ..tl.mtproto.functions import ping_delay_disconnect
|
||||||
from ..tl.types import UpdateDeleteMessages, UpdateShort
|
from ..tl.types import UpdateDeleteMessages, UpdateShort
|
||||||
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
|
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
|
||||||
|
from .reconnection import ReconnectionPolicy
|
||||||
|
|
||||||
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
|
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
|
||||||
|
|
||||||
@ -164,6 +165,7 @@ class Sender:
|
|||||||
dc_id: int
|
dc_id: int
|
||||||
addr: str
|
addr: str
|
||||||
_connector: Connector
|
_connector: Connector
|
||||||
|
_reconnection_policy: Optional[ReconnectionPolicy]
|
||||||
_logger: logging.Logger
|
_logger: logging.Logger
|
||||||
_reader: AsyncReader
|
_reader: AsyncReader
|
||||||
_writer: AsyncWriter
|
_writer: AsyncWriter
|
||||||
@ -188,6 +190,7 @@ class Sender:
|
|||||||
addr: str,
|
addr: str,
|
||||||
*,
|
*,
|
||||||
connector: Connector,
|
connector: Connector,
|
||||||
|
reconnection_policy: Optional[ReconnectionPolicy],
|
||||||
base_logger: logging.Logger,
|
base_logger: logging.Logger,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
ip, port = addr.split(":")
|
ip, port = addr.split(":")
|
||||||
@ -197,6 +200,7 @@ class Sender:
|
|||||||
dc_id=dc_id,
|
dc_id=dc_id,
|
||||||
addr=addr,
|
addr=addr,
|
||||||
_connector=connector,
|
_connector=connector,
|
||||||
|
_reconnection_policy=reconnection_policy,
|
||||||
_logger=base_logger.getChild("mtsender"),
|
_logger=base_logger.getChild("mtsender"),
|
||||||
_reader=reader,
|
_reader=reader,
|
||||||
_writer=writer,
|
_writer=writer,
|
||||||
@ -536,6 +540,7 @@ async def connect(
|
|||||||
auth_key: Optional[bytes],
|
auth_key: Optional[bytes],
|
||||||
base_logger: logging.Logger,
|
base_logger: logging.Logger,
|
||||||
connector: Connector,
|
connector: Connector,
|
||||||
|
reconnection_policy: Optional[ReconnectionPolicy] = None,
|
||||||
) -> Sender:
|
) -> Sender:
|
||||||
if auth_key is None:
|
if auth_key is None:
|
||||||
sender = await Sender.connect(
|
sender = await Sender.connect(
|
||||||
@ -544,6 +549,7 @@ async def connect(
|
|||||||
dc_id,
|
dc_id,
|
||||||
addr,
|
addr,
|
||||||
connector=connector,
|
connector=connector,
|
||||||
|
reconnection_policy=reconnection_policy,
|
||||||
base_logger=base_logger,
|
base_logger=base_logger,
|
||||||
)
|
)
|
||||||
return await generate_auth_key(sender)
|
return await generate_auth_key(sender)
|
||||||
@ -554,6 +560,7 @@ async def connect(
|
|||||||
dc_id,
|
dc_id,
|
||||||
addr,
|
addr,
|
||||||
connector=connector,
|
connector=connector,
|
||||||
|
reconnection_policy=reconnection_policy,
|
||||||
base_logger=base_logger,
|
base_logger=base_logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user