diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index 2d434022..91564209 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -10,6 +10,7 @@ from typing_extensions import Self from ....version import __version__ as default_version from ...mtsender import Connector, Sender +from ...mtsender.reconnection import NoReconnect, ReconnectionPolicy from ...session import ( ChannelRef, ChatHashCache, @@ -215,6 +216,7 @@ class Client: lang_code: Optional[str] = None, datacenter: Optional[DataCenter] = None, connector: Optional[Connector] = None, + reconnection_policy: Optional[ReconnectionPolicy] = None, ) -> None: assert __package__ base_logger = logger or logging.getLogger(__package__[: __package__.index(".")]) @@ -246,6 +248,7 @@ class Client: update_queue_limit=update_queue_limit, base_logger=base_logger, connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), + reconnection_policy=reconnection_policy or NoReconnect(), ) self._session = Session() @@ -253,9 +256,9 @@ class Client: self._message_box = MessageBox(base_logger=base_logger) self._chat_hashes = ChatHashCache(None) self._last_update_limit_warn: Optional[float] = None - self._updates: asyncio.Queue[ - tuple[abcs.Update, dict[int, Peer]] - ] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0) + self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = ( + asyncio.Queue(maxsize=self._config.update_queue_limit or 0) + ) self._dispatcher: Optional[asyncio.Task[None]] = None self._handlers: dict[ Type[Event], diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 808c4b6b..e790457d 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar from ....version import __version__ from ...mtproto import BadStatusError, Full, RpcError -from ...mtsender import Connector, Sender +from ...mtsender import Connector, ReconnectionPolicy, Sender from ...mtsender import connect as do_connect_sender from ...session import DataCenter from ...session import User as SessionUser @@ -46,6 +46,7 @@ class Config: api_hash: str base_logger: logging.Logger connector: Connector + reconnection_policy: ReconnectionPolicy device_model: str = field(default_factory=default_device_model) system_version: str = field(default_factory=default_system_version) app_version: str = __version__ @@ -100,6 +101,7 @@ async def connect_sender( auth_key=auth, base_logger=config.base_logger, connector=config.connector, + reconnection_policy=config.reconnection_policy, ) try: diff --git a/client/src/telethon/_impl/crypto/crypto.py b/client/src/telethon/_impl/crypto/crypto.py index 17fd788e..a5a60a38 100644 --- a/client/src/telethon/_impl/crypto/crypto.py +++ b/client/src/telethon/_impl/crypto/crypto.py @@ -7,6 +7,24 @@ from .aes import ige_decrypt, ige_encrypt from .auth_key import AuthKey +class InvalidBufferError(ValueError): + def __init__(self) -> None: + super().__init__("invalid ciphertext buffer length") + + +class AuthKeyMismatchError(ValueError): + def __init__(self) -> None: + super().__init__("server authkey mismatches with ours") + + +class MsgKeyMismatchError(ValueError): + def __init__(self) -> None: + super().__init__("server msgkey mismatches with ours") + + +CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError + + # "where x = 0 for messages from client to server and x = 8 for those from server to client" class Side(IntEnum): CLIENT = 0 @@ -77,14 +95,14 @@ def decrypt_data_v2( x = int(side) if len(ciphertext) < 24 or (len(ciphertext) - 24) % 16 != 0: - raise ValueError("invalid ciphertext buffer length") + raise InvalidBufferError() # salt, session_id and sequence_number should also be checked. # However, not doing so has worked fine for years. key_id = ciphertext[:8] if auth_key.key_id != key_id: - raise ValueError("server authkey mismatches with ours") + raise AuthKeyMismatchError() msg_key = ciphertext[8:24] key, iv = calc_key(auth_key, msg_key, side) @@ -93,7 +111,7 @@ def decrypt_data_v2( # https://core.telegram.org/mtproto/security_guidelines#mtproto-encrypted-messages our_key = sha256(auth_key.data[x + 88 : x + 120] + plaintext).digest() if msg_key != our_key[8:24]: - raise ValueError("server msgkey mismatches with ours") + raise MsgKeyMismatchError() return plaintext diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index abcb7015..a9290fe8 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -62,11 +62,15 @@ from ..utils import ( ) from .types import ( BadMessageError, + DecompressionFailed, Deserialization, + DeserializationFailure, + MsgBufferTooSmall, MsgId, Mtp, RpcError, RpcResult, + UnexpectedConstructor, Update, ) @@ -85,6 +89,7 @@ UPDATE_IDS = { AffectedFoundMessages.constructor_id(), AffectedHistory.constructor_id(), AffectedMessages.constructor_id(), + # TODO InvitedUsers } HEADER_LEN = 8 + 8 # salt, client_id @@ -151,7 +156,7 @@ class Encrypted(Mtp): self._last_msg_id: int self._in_pending_ack: list[int] = [] self._msg_count: int - self._reset_session() + self.reset() @property def auth_key(self) -> bytes: @@ -166,13 +171,6 @@ class Encrypted(Mtp): def _adjusted_now(self) -> float: return time.time() + self._time_offset - def _reset_session(self) -> None: - self._client_id = struct.unpack(" int: new_msg_id = int(self._adjusted_now() * 0x100000000) if self._last_msg_id >= new_msg_id: @@ -245,12 +243,38 @@ class Encrypted(Mtp): result = rpc_result.result msg_id = MsgId(req_msg_id) - inner_constructor = struct.unpack_from(" None: + raise RuntimeError("msg_copy should not be used") + def _handle_gzip_packed(self, message: Message) -> None: container = GzipPacked.from_bytes(message.body) inner_body = gzip_decompress(container) @@ -459,3 +492,11 @@ class Encrypted(Mtp): result = self._deserialization[:] self._deserialization.clear() return result + + def reset(self) -> None: + self._client_id = struct.unpack("= 0, got: {length}") + raise NegativeLengthError(got=length) - if 20 + length > len(payload): - raise ValueError( - f"message too short, expected: {20 + length}, got {len(payload)}" - ) + if 20 + length > (lp := len(payload)): + raise TooLongMsgError(got=length, max_length=lp - 20) return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))] + + def reset(self) -> None: + self._buffer.clear() diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index 0f624c7c..4d4b43bc 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -5,6 +5,7 @@ from typing import NewType, Optional from typing_extensions import Self +from ...crypto.crypto import CryptoError from ...tl.mtproto.types import RpcError as GeneratedRpcError MsgId = NewType("MsgId", int) @@ -180,7 +181,112 @@ class BadMessageError(ValueError): return self._code == other._code -Deserialization = Update | RpcResult | RpcError | BadMessageError +# Deserialization errors are not fatal, so we don't subclass RpcError. +class BadAuthKeyError(ValueError): + def __init__(self, *args: object, got: int, expected: int) -> None: + super().__init__(f"bad server auth key (got {got}, expected {expected})", *args) + self._got = got + self._expected = expected + + @property + def got(self): + return self._got + + @property + def expected(self): + return self._expected + + +class BadMsgIdError(ValueError): + def __init__(self, *args: object, got: int) -> None: + super().__init__(f"bad server message id (got {got})", *args) + self._got = got + + @property + def got(self): + return self._got + + +class NegativeLengthError(ValueError): + def __init__(self, *args: object, got: int) -> None: + super().__init__(f"bad server message length (got {got})", *args) + self._got = got + + @property + def got(self): + return self._got + + +class TooLongMsgError(ValueError): + def __init__(self, *args: object, got: int, max_length: int) -> None: + super().__init__( + f"bad server message length (got {got}, when at most it should be {max_length})", + *args, + ) + self._got = got + self._expected = max_length + + @property + def got(self): + return self._got + + @property + def expected(self): + return self._expected + + +class MsgBufferTooSmall(ValueError): + def __init__(self, *args: object) -> None: + super().__init__( + "server responded with a payload that's too small to fit a valid message", + *args, + ) + + +class DecompressionFailed(ValueError): + def __init__(self, *args: object) -> None: + super().__init__("failed to decompress server's data", *args) + + +class UnexpectedConstructor(ValueError): + def __init__(self, *args: object, id: int) -> None: + super().__init__(f"unexpected constructor: {id:08x}", *args) + + +class DecryptionError(ValueError): + def __init__(self, *args: object, error: CryptoError) -> None: + super().__init__(f"failed to decrypt message: {error}", *args) + + self._error = error + + @property + def error(self): + return self._error + + +DeserializationError = ( + BadAuthKeyError + | BadMsgIdError + | NegativeLengthError + | TooLongMsgError + | MsgBufferTooSmall + | DecompressionFailed + | UnexpectedConstructor + | DecryptionError +) + + +class DeserializationFailure: + __slots__ = ("msg_id", "error") + + def __init__(self, msg_id: MsgId, error: DeserializationError) -> None: + self.msg_id = msg_id + self.error = error + + +Deserialization = ( + Update | RpcResult | RpcError | BadMessageError | DeserializationFailure +) # https://core.telegram.org/mtproto/description @@ -209,3 +315,9 @@ class Mtp(ABC): """ Deserialize incoming buffer payload. """ + + @abstractmethod + def reset(self) -> None: + """ + Reset the internal buffer. + """ diff --git a/client/src/telethon/_impl/mtproto/transport/abcs.py b/client/src/telethon/_impl/mtproto/transport/abcs.py index d1b01956..5e319655 100644 --- a/client/src/telethon/_impl/mtproto/transport/abcs.py +++ b/client/src/telethon/_impl/mtproto/transport/abcs.py @@ -15,13 +15,37 @@ class Transport(ABC): def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int: pass + @abstractmethod + def reset(self): + pass + class MissingBytesError(ValueError): def __init__(self, *, expected: int, got: int) -> None: super().__init__(f"missing bytes, expected: {expected}, got: {got}") +class BadLenError(ValueError): + def __init__(self, *, got: int) -> None: + super().__init__(f"bad len (got {got})") + + +class BadSeqError(ValueError): + def __init__(self, *, expected: int, got: int) -> None: + super().__init__(f"bad seq (expected {expected}, got {got})") + + +class BadCrcError(ValueError): + def __init__(self, *, expected: int, got: int) -> None: + super().__init__(f"bad crc (expected {expected}, got {got})") + + class BadStatusError(ValueError): def __init__(self, *, status: int) -> None: - super().__init__(f"transport reported bad status: {status}") + super().__init__(f"bad status (negative length -{status})") self.status = status + + +TransportError = ( + MissingBytesError | BadLenError | BadSeqError | BadCrcError | BadStatusError +) diff --git a/client/src/telethon/_impl/mtproto/transport/abridged.py b/client/src/telethon/_impl/mtproto/transport/abridged.py index 1f8cf482..2cd4fc9d 100644 --- a/client/src/telethon/_impl/mtproto/transport/abridged.py +++ b/client/src/telethon/_impl/mtproto/transport/abridged.py @@ -1,3 +1,4 @@ +import logging import struct from .abcs import BadStatusError, MissingBytesError, OutFn, Transport @@ -60,3 +61,7 @@ class Abridged(Transport): output += memoryview(input)[header_len : header_len + length] return header_len + length + + def reset(self): + logging.info("resetting sending of header in abridged transport") + self._init = False diff --git a/client/src/telethon/_impl/mtproto/transport/full.py b/client/src/telethon/_impl/mtproto/transport/full.py index 59cc1e2c..c8a24ed1 100644 --- a/client/src/telethon/_impl/mtproto/transport/full.py +++ b/client/src/telethon/_impl/mtproto/transport/full.py @@ -1,3 +1,4 @@ +import logging import struct from zlib import crc32 @@ -61,3 +62,8 @@ class Full(Transport): self._recv_seq += 1 output += memoryview(input)[8 : length - 4] return length + + def reset(self): + logging.info("resetting recv and send seqs in full transport") + self._send_seq = 0 + self._recv_seq = 0 diff --git a/client/src/telethon/_impl/mtproto/transport/intermediate.py b/client/src/telethon/_impl/mtproto/transport/intermediate.py index 2f5b434e..a533ab9a 100644 --- a/client/src/telethon/_impl/mtproto/transport/intermediate.py +++ b/client/src/telethon/_impl/mtproto/transport/intermediate.py @@ -1,3 +1,4 @@ +import logging import struct from .abcs import BadStatusError, MissingBytesError, OutFn, Transport @@ -52,3 +53,7 @@ class Intermediate(Transport): output += memoryview(input)[4 : 4 + length] return length + 4 + + def reset(self): + logging.info("resetting sending of header in intermediate transport") + self._init = False diff --git a/client/src/telethon/_impl/mtsender/__init__.py b/client/src/telethon/_impl/mtsender/__init__.py index bc9d723f..aaf1fc6d 100644 --- a/client/src/telethon/_impl/mtsender/__init__.py +++ b/client/src/telethon/_impl/mtsender/__init__.py @@ -1,3 +1,4 @@ +from .reconnection import ReconnectionPolicy from .sender import ( MAXIMUM_DATA, NO_PING_DISCONNECT, @@ -18,4 +19,5 @@ __all__ = [ "Connector", "Sender", "connect", + "ReconnectionPolicy", ] diff --git a/client/src/telethon/_impl/mtsender/errors.py b/client/src/telethon/_impl/mtsender/errors.py new file mode 100644 index 00000000..26c94f06 --- /dev/null +++ b/client/src/telethon/_impl/mtsender/errors.py @@ -0,0 +1,6 @@ +from struct import error as struct_error + +from ..mtproto.mtp.types import DeserializationError +from ..mtproto.transport.abcs import TransportError + +ReadError = struct_error | TransportError | DeserializationError diff --git a/client/src/telethon/_impl/mtsender/reconnection.py b/client/src/telethon/_impl/mtsender/reconnection.py new file mode 100644 index 00000000..363aad6d --- /dev/null +++ b/client/src/telethon/_impl/mtsender/reconnection.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod + + +class ReconnectionPolicy(ABC): + """ + Base class for reconnection policies. + + This class defines the interface for reconnection policies used by the MTSender. + It allows for custom reconnection strategies to be implemented by subclasses. + """ + + @abstractmethod + def should_retry(self, attempts: int) -> bool | float: + """ + Determines whether the client should retry the connection attempt. + """ + pass + + +class NoReconnect(ReconnectionPolicy): + def should_retry(self, attempts: int) -> bool: + return False + + +class FixedReconnect(ReconnectionPolicy): + __slots__ = ("max_attempts", "delay") + + def __init__(self, attempts: int, delay: float): + self.max_attempts = attempts + self.delay = delay + + def should_retry(self, attempts: int) -> bool | float: + if attempts < self.max_attempts: + return self.delay + + return False diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 1a46f6b9..f9b8c1f5 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -24,12 +24,14 @@ from ..mtproto import ( Update, authentication, ) +from ..mtproto.mtp.types import DeserializationFailure from ..tl import Request as RemoteCall from ..tl.abcs import Updates from ..tl.core import Serializable from ..tl.mtproto.functions import ping_delay_disconnect -from ..tl.types import UpdateDeleteMessages, UpdateShort +from ..tl.types import UpdateDeleteMessages, UpdateShort, UpdatesTooLong from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages +from .reconnection import ReconnectionPolicy MAXIMUM_DATA = (1024 * 1024) + (8 * 1024) @@ -162,6 +164,8 @@ class Request(Generic[Return]): class Sender: dc_id: int addr: str + _connector: Connector + _reconnection_policy: ReconnectionPolicy _logger: logging.Logger _reader: AsyncReader _writer: AsyncWriter @@ -175,6 +179,7 @@ class Sender: _requests: list[Request[object]] _next_ping: float _read_buffer: bytearray + _write_drain_pending: bool @classmethod async def connect( @@ -185,6 +190,7 @@ class Sender: addr: str, *, connector: Connector, + reconnection_policy: ReconnectionPolicy, base_logger: logging.Logger, ) -> Self: ip, port = addr.split(":") @@ -193,6 +199,8 @@ class Sender: return cls( dc_id=dc_id, addr=addr, + _connector=connector, + _reconnection_policy=reconnection_policy, _logger=base_logger.getChild("mtsender"), _reader=reader, _writer=writer, @@ -206,6 +214,7 @@ class Sender: _requests=[], _next_ping=asyncio.get_running_loop().time() + PING_DELAY, _read_buffer=bytearray(), + _write_drain_pending=False, ) async def disconnect(self) -> None: @@ -236,15 +245,21 @@ class Sender: if rx.done(): return rx.result() - async def step(self) -> None: + async def step(self): + try: + await self._step() + except Exception as error: + await self._on_error(error) + + async def _step(self) -> None: if not self._writing: self._writing = True - await self._do_write() + await self._do_send() self._writing = False if not self._reading: self._reading = True - await self._do_read() + await self._do_recv() self._reading = False else: await self._step_done.wait() @@ -254,7 +269,7 @@ class Sender: self._updates.clear() return updates - async def _do_read(self) -> None: + async def _do_recv(self) -> None: self._step_done.clear() timeout = self._next_ping - asyncio.get_running_loop().time() @@ -266,10 +281,46 @@ class Sender: else: self._on_net_read(recv_data) finally: - self._try_timeout_ping() + self._try_ping_timeout() self._step_done.set() - async def _do_write(self) -> None: + async def _do_send(self) -> None: + self._try_fill_write() + + if self._write_drain_pending: + await self._writer.drain() + self._on_net_write() + + async def _try_connect(self): + attempts = 0 + + ip, port = self.addr.split(":") + + while True: + try: + self._reader, self._writer = await self._connector(ip, int(port)) + self._logger.info( + f"auto-reconnect success after {attempts} failed attempt(s)" + ) + return + except Exception as e: + attempts += 1 + self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}") + await asyncio.sleep(1) + + delay = self._reconnection_policy.should_retry(attempts) + + if delay: + if delay is not True: + await asyncio.sleep(delay) + continue + else: + self._logger.error( + f"auto-reconnect failed {attempts} time(s); giving up" + ) + raise + + def _try_fill_write(self) -> None: if not self._requests: return @@ -283,15 +334,14 @@ class Sender: result = self._mtp.finalize() if result: container_msg_id, mtp_buffer = result - - self._transport.pack(mtp_buffer, self._writer.write) - await self._writer.drain() - for request in self._requests: if isinstance(request.state, Serialized): - request.state = Sent(request.state.msg_id, container_msg_id) + request.state.container_msg_id = container_msg_id - def _try_timeout_ping(self) -> None: + self._transport.pack(mtp_buffer, self._writer.write) + self._write_drain_pending = True + + def _try_ping_timeout(self) -> None: current_time = asyncio.get_running_loop().time() if current_time >= self._next_ping: @@ -321,6 +371,45 @@ class Sender: del self._read_buffer[:n] self._process_mtp_buffer() + def _on_net_write(self) -> None: + for req in self._requests: + if isinstance(req.state, Serialized): + req.state = Sent(req.state.msg_id, req.state.container_msg_id) + + async def _on_error(self, error: Exception) -> None: + self._logger.info(f"handling error: {error}") + self._transport.reset() + self._mtp.reset() + self._logger.info( + "resetting sender state from read_buffer {}, mtp_buffer {}".format( + len(self._read_buffer), + len(self._mtp_buffer), + ) + ) + self._read_buffer.clear() + self._mtp_buffer.clear() + + if isinstance(error, struct.error) and self._reconnection_policy.should_retry( + 0 + ): + self._logger.info(f"read error occurred: {error}") + await self._try_connect() + + for req in self._requests: + req.state = NotSerialized() + + self._updates.append(UpdatesTooLong()) + return + + self._logger.warning( + f"marking all {len(self._requests)} request(s) as failed: {error}" + ) + + for req in self._requests: + req.result.set_exception(error) + + raise error + def _process_mtp_buffer(self) -> None: results = self._mtp.deserialize(self._mtp_buffer) @@ -331,8 +420,14 @@ class Sender: self._process_result(result) elif isinstance(result, RpcError): self._process_error(result) - else: + elif isinstance(result, BadMessageError): self._process_bad_message(result) + elif isinstance(result, DeserializationFailure): + self._process_deserialize_error(result) + else: + raise RuntimeError( + f"unexpected result type {type(result).__name__}: {result}" + ) def _process_update(self, update: bytes | bytearray | memoryview) -> None: try: @@ -421,6 +516,17 @@ class Sender: result._caused_by = struct.unpack_from(" Optional[Request[object]]: for i, req in enumerate(self._requests): if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: @@ -459,6 +565,7 @@ async def connect( auth_key: Optional[bytes], base_logger: logging.Logger, connector: Connector, + reconnection_policy: ReconnectionPolicy, ) -> Sender: if auth_key is None: sender = await Sender.connect( @@ -467,6 +574,7 @@ async def connect( dc_id, addr, connector=connector, + reconnection_policy=reconnection_policy, base_logger=base_logger, ) return await generate_auth_key(sender) @@ -477,6 +585,7 @@ async def connect( dc_id, addr, connector=connector, + reconnection_policy=reconnection_policy, base_logger=base_logger, )