From 59daea32c7af0288770f843cfb58881e80b0f6f1 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Sat, 31 May 2025 13:30:43 +0500 Subject: [PATCH 01/16] Implement equality check for Request class and optimize request removal in Sender class --- client/src/telethon/_impl/mtsender/sender.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 1a46f6b9..4757ba22 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -157,6 +157,9 @@ class Request(Generic[Return]): state: RequestState result: Future[Return] + def __eq__(self, value: object) -> bool: + return self is value + @dataclass class Sender: @@ -422,18 +425,17 @@ class Sender: req.result.set_exception(result) def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]: - for i, req in enumerate(self._requests): + for req in self._requests: if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: raise RuntimeError("got response for unsent request") elif isinstance(req.state, Sent) and req.state.msg_id == msg_id: - del self._requests[i] + self._requests.remove(req) return req return None def _drain_requests(self, msg_id: MsgId) -> Iterator[Request[object]]: - for i in reversed(range(len(self._requests))): - req = self._requests[i] + for req in self._requests: if isinstance(req.state, Serialized) and ( req.state.msg_id == msg_id or req.state.container_msg_id == msg_id ): @@ -441,7 +443,8 @@ class Sender: elif isinstance(req.state, Sent) and ( req.state.msg_id == msg_id or req.state.container_msg_id == msg_id ): - yield self._requests.pop(i) + self._requests.remove(req) + yield req @property def auth_key(self) -> Optional[bytes]: From 1c2aafce2a0227a565514ee365a014efff1c17ba Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Sun, 1 Jun 2025 19:55:25 +0500 Subject: [PATCH 02/16] Refactor error handling in crypto and mtproto modules; introduce custom error classes and improve deserialization error processing in Sender class --- client/src/telethon/_impl/crypto/crypto.py | 24 +++- .../telethon/_impl/mtproto/mtp/encrypted.py | 73 +++++++++--- .../src/telethon/_impl/mtproto/mtp/plain.py | 9 +- .../src/telethon/_impl/mtproto/mtp/types.py | 107 +++++++++++++++++- client/src/telethon/_impl/mtsender/errors.py | 0 client/src/telethon/_impl/mtsender/sender.py | 18 ++- 6 files changed, 206 insertions(+), 25 deletions(-) create mode 100644 client/src/telethon/_impl/mtsender/errors.py diff --git a/client/src/telethon/_impl/crypto/crypto.py b/client/src/telethon/_impl/crypto/crypto.py index 17fd788e..1c47962e 100644 --- a/client/src/telethon/_impl/crypto/crypto.py +++ b/client/src/telethon/_impl/crypto/crypto.py @@ -7,6 +7,24 @@ from .aes import ige_decrypt, ige_encrypt from .auth_key import AuthKey +class InvalidBufferError(ValueError): + def __init__(self) -> None: + super().__init__("Invalid ciphertext buffer length") + + +class AuthKeyMismatchError(ValueError): + def __init__(self) -> None: + super().__init__("Server authkey mismatches with ours") + + +class MsgKeyMismatchError(ValueError): + def __init__(self) -> None: + super().__init__("Server msgkey mismatches with ours") + + +CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError + + # "where x = 0 for messages from client to server and x = 8 for those from server to client" class Side(IntEnum): CLIENT = 0 @@ -77,14 +95,14 @@ def decrypt_data_v2( x = int(side) if len(ciphertext) < 24 or (len(ciphertext) - 24) % 16 != 0: - raise ValueError("invalid ciphertext buffer length") + raise InvalidBufferError() # salt, session_id and sequence_number should also be checked. # However, not doing so has worked fine for years. key_id = ciphertext[:8] if auth_key.key_id != key_id: - raise ValueError("server authkey mismatches with ours") + raise AuthKeyMismatchError() msg_key = ciphertext[8:24] key, iv = calc_key(auth_key, msg_key, side) @@ -93,7 +111,7 @@ def decrypt_data_v2( # https://core.telegram.org/mtproto/security_guidelines#mtproto-encrypted-messages our_key = sha256(auth_key.data[x + 88 : x + 120] + plaintext).digest() if msg_key != our_key[8:24]: - raise ValueError("server msgkey mismatches with ours") + raise MsgKeyMismatchError() return plaintext diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index abcb7015..1f94bb3b 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -62,11 +62,15 @@ from ..utils import ( ) from .types import ( BadMessageError, + DecompressionFailed, Deserialization, + DeserializationFailure, + MsgBufferTooSmall, MsgId, Mtp, RpcError, RpcResult, + UnexpectedConstructor, Update, ) @@ -85,6 +89,7 @@ UPDATE_IDS = { AffectedFoundMessages.constructor_id(), AffectedHistory.constructor_id(), AffectedMessages.constructor_id(), + # TODO InvitedUsers } HEADER_LEN = 8 + 8 # salt, client_id @@ -151,7 +156,7 @@ class Encrypted(Mtp): self._last_msg_id: int self._in_pending_ack: list[int] = [] self._msg_count: int - self._reset_session() + self.reset() @property def auth_key(self) -> bytes: @@ -166,13 +171,6 @@ class Encrypted(Mtp): def _adjusted_now(self) -> float: return time.time() + self._time_offset - def _reset_session(self) -> None: - self._client_id = struct.unpack(" int: new_msg_id = int(self._adjusted_now() * 0x100000000) if self._last_msg_id >= new_msg_id: @@ -245,12 +243,38 @@ class Encrypted(Mtp): result = rpc_result.result msg_id = MsgId(req_msg_id) - inner_constructor = struct.unpack_from(" None: + raise RuntimeError("msg_copy should not be used") + def _handle_gzip_packed(self, message: Message) -> None: container = GzipPacked.from_bytes(message.body) inner_body = gzip_decompress(container) @@ -459,3 +492,11 @@ class Encrypted(Mtp): result = self._deserialization[:] self._deserialization.clear() return result + + def reset(self) -> None: + self._client_id = struct.unpack("= 0, got: {length}") - if 20 + length > len(payload): - raise ValueError( - f"message too short, expected: {20 + length}, got {len(payload)}" - ) + if 20 + length > (lp := len(payload)): + raise ValueError(f"message too short, expected: {20 + length}, got {lp}") return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))] + + def reset(self) -> None: + self._buffer.clear() diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index 0f624c7c..7fe4b209 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -5,6 +5,7 @@ from typing import NewType, Optional from typing_extensions import Self +from ...crypto.crypto import CryptoError from ...tl.mtproto.types import RpcError as GeneratedRpcError MsgId = NewType("MsgId", int) @@ -180,7 +181,105 @@ class BadMessageError(ValueError): return self._code == other._code -Deserialization = Update | RpcResult | RpcError | BadMessageError +DeserializationError = ValueError + + +class DeserializationFailure: + __slots__ = ("msg_id", "error") + + def __init__(self, msg_id: MsgId, error: DeserializationError) -> None: + self.msg_id = msg_id + self.error = error + + +Deserialization = ( + Update | RpcResult | RpcError | BadMessageError | DeserializationFailure +) + + +# Deserialization errors are not fatal, so we don't subclass RpcError. +class BadAuthKeyError(DeserializationError): + def __init__(self, *args: object, got: int, expected: int) -> None: + super().__init__(f"Bad server auth key (got {got}, expected {expected})", *args) + self._got = got + self._expected = expected + + @property + def got(self): + return self._got + + @property + def expected(self): + return self._expected + + +class BadMsgIdError(DeserializationError): + def __init__(self, *args: object, got: int) -> None: + super().__init__(f"Bad server message id (got {got})", *args) + self._got = got + + @property + def got(self): + return self._got + + +class NegativeLengthError(DeserializationError): + def __init__(self, *args: object, got: int) -> None: + super().__init__(f"Bad server message length (got {got})", *args) + self._got = got + + @property + def got(self): + return self._got + + +class TooLongMsgError(DeserializationError): + __slots__ = ("expected", "got") + + def __init__(self, *args: object, got: int, max_length: int) -> None: + super().__init__( + f"Bad server message length (got {got}, when at most it should be {max_length})", + *args, + ) + self._got = got + self._expected = max_length + + @property + def got(self): + return self._got + + @property + def expected(self): + return self._expected + + +class MsgBufferTooSmall(DeserializationError): + def __init__(self, *args: object) -> None: + super().__init__( + "Server responded with a payload that's too small to fit a valid message", + *args, + ) + + +class DecompressionFailed(DeserializationError): + def __init__(self, *args: object) -> None: + super().__init__("Failed to decompress server's data", *args) + + +class UnexpectedConstructor(DeserializationError): + def __init__(self, *args: object, id: int) -> None: + super().__init__(f"Unexpected constructor: {id:08x}", *args) + + +class DecryptionError(DeserializationError): + def __init__(self, *args: object, error: CryptoError) -> None: + super().__init__(f"failed to decrypt message: {error}", *args) + + self._error = error + + @property + def error(self): + return self._error # https://core.telegram.org/mtproto/description @@ -209,3 +308,9 @@ class Mtp(ABC): """ Deserialize incoming buffer payload. """ + + @abstractmethod + def reset(self) -> None: + """ + Reset the internal buffer. + """ diff --git a/client/src/telethon/_impl/mtsender/errors.py b/client/src/telethon/_impl/mtsender/errors.py new file mode 100644 index 00000000..e69de29b diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 4757ba22..0361c635 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -24,6 +24,7 @@ from ..mtproto import ( Update, authentication, ) +from ..mtproto.mtp.types import DeserializationFailure from ..tl import Request as RemoteCall from ..tl.abcs import Updates from ..tl.core import Serializable @@ -334,8 +335,12 @@ class Sender: self._process_result(result) elif isinstance(result, RpcError): self._process_error(result) - else: + elif isinstance(result, BadMessageError): self._process_bad_message(result) + elif isinstance(result, DeserializationFailure): + self._process_deserialize_error(result) + else: + raise RuntimeError(f"Unexpected result: {result}") def _process_update(self, update: bytes | bytearray | memoryview) -> None: try: @@ -424,6 +429,17 @@ class Sender: result._caused_by = struct.unpack_from(" Optional[Request[object]]: for req in self._requests: if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: From d590271ebd0ed43f3469c9367a2f72daf6e09513 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Sun, 1 Jun 2025 20:00:42 +0500 Subject: [PATCH 03/16] Refactor Request class equality check and optimize request removal in Sender class --- client/src/telethon/_impl/mtsender/sender.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 0361c635..469b85db 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -158,9 +158,6 @@ class Request(Generic[Return]): state: RequestState result: Future[Return] - def __eq__(self, value: object) -> bool: - return self is value - @dataclass class Sender: @@ -441,17 +438,19 @@ class Sender: ) def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]: - for req in self._requests: + for i, req in enumerate(self._requests): if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: raise RuntimeError("got response for unsent request") elif isinstance(req.state, Sent) and req.state.msg_id == msg_id: - self._requests.remove(req) + del self._requests[i] return req return None def _drain_requests(self, msg_id: MsgId) -> Iterator[Request[object]]: - for req in self._requests: + for i in reversed(range(len(self._requests))): + req = self._requests[i] + if isinstance(req.state, Serialized) and ( req.state.msg_id == msg_id or req.state.container_msg_id == msg_id ): @@ -459,8 +458,7 @@ class Sender: elif isinstance(req.state, Sent) and ( req.state.msg_id == msg_id or req.state.container_msg_id == msg_id ): - self._requests.remove(req) - yield req + yield self._requests.pop(i) @property def auth_key(self) -> Optional[bytes]: From 602bb6381affab9fb3a50d057f1768359080ac40 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Sun, 1 Jun 2025 20:01:25 +0500 Subject: [PATCH 04/16] Remove unused error handling module in mtsender --- client/src/telethon/_impl/mtsender/errors.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 client/src/telethon/_impl/mtsender/errors.py diff --git a/client/src/telethon/_impl/mtsender/errors.py b/client/src/telethon/_impl/mtsender/errors.py deleted file mode 100644 index e69de29b..00000000 From 05af28d5e1b79d8ad99a4a79bbab291cba90ef01 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Sun, 1 Jun 2025 20:17:29 +0500 Subject: [PATCH 05/16] Implement reset method in transport classes and add logging for state resets --- client/src/telethon/_impl/mtproto/transport/abcs.py | 4 ++++ client/src/telethon/_impl/mtproto/transport/abridged.py | 5 +++++ client/src/telethon/_impl/mtproto/transport/full.py | 6 ++++++ client/src/telethon/_impl/mtproto/transport/intermediate.py | 5 +++++ 4 files changed, 20 insertions(+) diff --git a/client/src/telethon/_impl/mtproto/transport/abcs.py b/client/src/telethon/_impl/mtproto/transport/abcs.py index d1b01956..b3a16f58 100644 --- a/client/src/telethon/_impl/mtproto/transport/abcs.py +++ b/client/src/telethon/_impl/mtproto/transport/abcs.py @@ -15,6 +15,10 @@ class Transport(ABC): def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int: pass + @abstractmethod + def reset(self): + pass + class MissingBytesError(ValueError): def __init__(self, *, expected: int, got: int) -> None: diff --git a/client/src/telethon/_impl/mtproto/transport/abridged.py b/client/src/telethon/_impl/mtproto/transport/abridged.py index 1f8cf482..65b92a14 100644 --- a/client/src/telethon/_impl/mtproto/transport/abridged.py +++ b/client/src/telethon/_impl/mtproto/transport/abridged.py @@ -1,3 +1,4 @@ +import logging import struct from .abcs import BadStatusError, MissingBytesError, OutFn, Transport @@ -60,3 +61,7 @@ class Abridged(Transport): output += memoryview(input)[header_len : header_len + length] return header_len + length + + def reset(self): + logging.info("Resetting sending of header in abridged transport") + self._init = False diff --git a/client/src/telethon/_impl/mtproto/transport/full.py b/client/src/telethon/_impl/mtproto/transport/full.py index 59cc1e2c..f04d751c 100644 --- a/client/src/telethon/_impl/mtproto/transport/full.py +++ b/client/src/telethon/_impl/mtproto/transport/full.py @@ -1,3 +1,4 @@ +import logging import struct from zlib import crc32 @@ -61,3 +62,8 @@ class Full(Transport): self._recv_seq += 1 output += memoryview(input)[8 : length - 4] return length + + def reset(self): + logging.info("Resetting recv and send seqs in full transport") + self._send_seq = 0 + self._recv_seq = 0 diff --git a/client/src/telethon/_impl/mtproto/transport/intermediate.py b/client/src/telethon/_impl/mtproto/transport/intermediate.py index 2f5b434e..e75adff6 100644 --- a/client/src/telethon/_impl/mtproto/transport/intermediate.py +++ b/client/src/telethon/_impl/mtproto/transport/intermediate.py @@ -1,3 +1,4 @@ +import logging import struct from .abcs import BadStatusError, MissingBytesError, OutFn, Transport @@ -52,3 +53,7 @@ class Intermediate(Transport): output += memoryview(input)[4 : 4 + length] return length + 4 + + def reset(self): + logging.info("Resetting sending of header in intermediate transport") + self._init = False From 66e9b537911d1754b2341b5bf0fba34d33bf94d5 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Sun, 1 Jun 2025 20:17:35 +0500 Subject: [PATCH 06/16] Add error handling in Sender class; reset transport and clear buffers on exception --- client/src/telethon/_impl/mtsender/sender.py | 47 +++++++++++++++----- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 469b85db..ac49a531 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -238,17 +238,20 @@ class Sender: return rx.result() async def step(self) -> None: - if not self._writing: - self._writing = True - await self._do_write() - self._writing = False + try: + if not self._writing: + self._writing = True + await self._do_write() + self._writing = False - if not self._reading: - self._reading = True - await self._do_read() - self._reading = False - else: - await self._step_done.wait() + if not self._reading: + self._reading = True + await self._do_read() + self._reading = False + else: + await self._step_done.wait() + except Exception: + self._ def pop_updates(self) -> list[Updates]: updates = self._updates[:] @@ -322,6 +325,30 @@ class Sender: del self._read_buffer[:n] self._process_mtp_buffer() + def _on_error(self, error: Exception): + logging.info(f"Handling error: {error}") + self._transport.reset() + self._mtp.reset() + logging.info( + "Resetting sender state from read_buffer {}, mtp_buffer {}".format( + len(self._read_buffer), + len(self._mtp_buffer), + ) + ) + self._read_buffer.clear() + self._mtp_buffer.clear() + + # TODO: reset + + logging.warning( + f"marking all {len(self._requests)} request(s) as failed: {error}" + ) + + for req in self._requests: + req.result.set_exception(error) + + raise error + def _process_mtp_buffer(self) -> None: results = self._mtp.deserialize(self._mtp_buffer) From bc43ae47181174b31986ee3c94e76163d1cd99fe Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Sun, 1 Jun 2025 20:27:15 +0500 Subject: [PATCH 07/16] Handle exceptions in Sender class by invoking error handler --- client/src/telethon/_impl/mtsender/sender.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index ac49a531..10ec0382 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -250,8 +250,8 @@ class Sender: self._reading = False else: await self._step_done.wait() - except Exception: - self._ + except Exception as e: + self._on_error(e) def pop_updates(self) -> list[Updates]: updates = self._updates[:] From bf4560a8c17d09c2ca0021faa2d312b8b0f64c09 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Sun, 1 Jun 2025 20:28:41 +0500 Subject: [PATCH 08/16] Remove unused __slots__ declaration from TooLongMsgError class --- client/src/telethon/_impl/mtproto/mtp/types.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index 7fe4b209..5594b7bb 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -234,8 +234,6 @@ class NegativeLengthError(DeserializationError): class TooLongMsgError(DeserializationError): - __slots__ = ("expected", "got") - def __init__(self, *args: object, got: int, max_length: int) -> None: super().__init__( f"Bad server message length (got {got}, when at most it should be {max_length})", From 8bcc3aa8b8d0b47f9d0dedfc089310b2fe9cc257 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Mon, 2 Jun 2025 00:20:22 +0500 Subject: [PATCH 09/16] Improve error logging in Encrypted class; specify failure reasons for unpacking, deserialization, and decompression --- client/src/telethon/_impl/mtproto/mtp/encrypted.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index 1f94bb3b..78d2faa3 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -246,10 +246,10 @@ class Encrypted(Mtp): try: inner_constructor = struct.unpack_from(" Date: Mon, 2 Jun 2025 01:15:48 +0500 Subject: [PATCH 10/16] Enhance Sender class with connection handling and refactor step methods for improved readability and error management --- client/src/telethon/_impl/mtsender/sender.py | 81 ++++++++++++++------ 1 file changed, 57 insertions(+), 24 deletions(-) diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 10ec0382..fee27620 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -163,6 +163,7 @@ class Request(Generic[Return]): class Sender: dc_id: int addr: str + _connector: Connector _logger: logging.Logger _reader: AsyncReader _writer: AsyncWriter @@ -176,6 +177,7 @@ class Sender: _requests: list[Request[object]] _next_ping: float _read_buffer: bytearray + _write_drain_pending: bool @classmethod async def connect( @@ -194,6 +196,7 @@ class Sender: return cls( dc_id=dc_id, addr=addr, + _connector=connector, _logger=base_logger.getChild("mtsender"), _reader=reader, _writer=writer, @@ -207,6 +210,7 @@ class Sender: _requests=[], _next_ping=asyncio.get_running_loop().time() + PING_DELAY, _read_buffer=bytearray(), + _write_drain_pending=False, ) async def disconnect(self) -> None: @@ -237,28 +241,31 @@ class Sender: if rx.done(): return rx.result() - async def step(self) -> None: + async def step(self): try: - if not self._writing: - self._writing = True - await self._do_write() - self._writing = False + await self._step() + except Exception as error: + self._on_error(error) - if not self._reading: - self._reading = True - await self._do_read() - self._reading = False - else: - await self._step_done.wait() - except Exception as e: - self._on_error(e) + async def _step(self) -> None: + if not self._writing: + self._writing = True + await self._do_send() + self._writing = False + + if not self._reading: + self._reading = True + await self._do_recv() + self._reading = False + else: + await self._step_done.wait() def pop_updates(self) -> list[Updates]: updates = self._updates[:] self._updates.clear() return updates - async def _do_read(self) -> None: + async def _do_recv(self) -> None: self._step_done.clear() timeout = self._next_ping - asyncio.get_running_loop().time() @@ -270,10 +277,31 @@ class Sender: else: self._on_net_read(recv_data) finally: - self._try_timeout_ping() + self._try_ping_timeout() self._step_done.set() - async def _do_write(self) -> None: + async def _do_send(self) -> None: + self._try_fill_write() + + if self._write_drain_pending: + await self._writer.drain() + self._on_net_write() + + async def try_connect(self): + # attempts = 0 + + ip, port = self.addr.split(":") + + while True: + try: + self._reader, self._writer = await self._connector(ip, int(port)) + break + except Exception as e: + logging.exception(e) + # TODO: reconnection_policy + break + + def _try_fill_write(self) -> None: if not self._requests: return @@ -287,15 +315,14 @@ class Sender: result = self._mtp.finalize() if result: container_msg_id, mtp_buffer = result - - self._transport.pack(mtp_buffer, self._writer.write) - await self._writer.drain() - for request in self._requests: if isinstance(request.state, Serialized): - request.state = Sent(request.state.msg_id, container_msg_id) + request.state.container_msg_id = container_msg_id - def _try_timeout_ping(self) -> None: + self._transport.pack(mtp_buffer, self._writer.write) + self._write_drain_pending = True + + def _try_ping_timeout(self) -> None: current_time = asyncio.get_running_loop().time() if current_time >= self._next_ping: @@ -325,6 +352,11 @@ class Sender: del self._read_buffer[:n] self._process_mtp_buffer() + def _on_net_write(self) -> None: + for req in self._requests: + if isinstance(req.state, Serialized): + req.state = Sent(req.state.msg_id, req.state.container_msg_id) + def _on_error(self, error: Exception): logging.info(f"Handling error: {error}") self._transport.reset() @@ -364,7 +396,9 @@ class Sender: elif isinstance(result, DeserializationFailure): self._process_deserialize_error(result) else: - raise RuntimeError(f"Unexpected result: {result}") + raise RuntimeError( + f"Unexpected result type {type(result).__name__!r}: {result}" + ) def _process_update(self, update: bytes | bytearray | memoryview) -> None: try: @@ -477,7 +511,6 @@ class Sender: def _drain_requests(self, msg_id: MsgId) -> Iterator[Request[object]]: for i in reversed(range(len(self._requests))): req = self._requests[i] - if isinstance(req.state, Serialized) and ( req.state.msg_id == msg_id or req.state.container_msg_id == msg_id ): From fdf2a05e3ebf1d7a6fc5a846566ed0202588a637 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Mon, 2 Jun 2025 15:32:36 +0500 Subject: [PATCH 11/16] Add reconnection policy support to Sender and related classes --- .../telethon/_impl/client/client/client.py | 7 ++-- .../src/telethon/_impl/client/client/net.py | 4 +- .../src/telethon/_impl/mtsender/__init__.py | 2 + .../telethon/_impl/mtsender/reconnection.py | 38 +++++++++++++++++++ client/src/telethon/_impl/mtsender/sender.py | 7 ++++ 5 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 client/src/telethon/_impl/mtsender/reconnection.py diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index 2d434022..8fec6a95 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -246,6 +246,7 @@ class Client: update_queue_limit=update_queue_limit, base_logger=base_logger, connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), + reconnection_policy=None, ) self._session = Session() @@ -253,9 +254,9 @@ class Client: self._message_box = MessageBox(base_logger=base_logger) self._chat_hashes = ChatHashCache(None) self._last_update_limit_warn: Optional[float] = None - self._updates: asyncio.Queue[ - tuple[abcs.Update, dict[int, Peer]] - ] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0) + self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = ( + asyncio.Queue(maxsize=self._config.update_queue_limit or 0) + ) self._dispatcher: Optional[asyncio.Task[None]] = None self._handlers: dict[ Type[Event], diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 808c4b6b..2770ed9a 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar from ....version import __version__ from ...mtproto import BadStatusError, Full, RpcError -from ...mtsender import Connector, Sender +from ...mtsender import Connector, ReconnectionPolicy, Sender from ...mtsender import connect as do_connect_sender from ...session import DataCenter from ...session import User as SessionUser @@ -55,6 +55,7 @@ class Config: datacenter: Optional[DataCenter] = None flood_sleep_threshold: int = 60 update_queue_limit: Optional[int] = None + reconnection_policy: Optional[ReconnectionPolicy] = None KNOWN_DCS = [ @@ -100,6 +101,7 @@ async def connect_sender( auth_key=auth, base_logger=config.base_logger, connector=config.connector, + reconnection_policy=config.reconnection_policy, ) try: diff --git a/client/src/telethon/_impl/mtsender/__init__.py b/client/src/telethon/_impl/mtsender/__init__.py index bc9d723f..aaf1fc6d 100644 --- a/client/src/telethon/_impl/mtsender/__init__.py +++ b/client/src/telethon/_impl/mtsender/__init__.py @@ -1,3 +1,4 @@ +from .reconnection import ReconnectionPolicy from .sender import ( MAXIMUM_DATA, NO_PING_DISCONNECT, @@ -18,4 +19,5 @@ __all__ = [ "Connector", "Sender", "connect", + "ReconnectionPolicy", ] diff --git a/client/src/telethon/_impl/mtsender/reconnection.py b/client/src/telethon/_impl/mtsender/reconnection.py new file mode 100644 index 00000000..2fb51ddd --- /dev/null +++ b/client/src/telethon/_impl/mtsender/reconnection.py @@ -0,0 +1,38 @@ +import time +from abc import ABC, abstractmethod + + +class ReconnectionPolicy(ABC): + """ + Base class for reconnection policies. + + This class defines the interface for reconnection policies used by the MTSender. + It allows for custom reconnection strategies to be implemented by subclasses. + """ + + @abstractmethod + def should_retry(self, attempts: int) -> bool: + """ + Determines whether the client should retry the connection attempt. + """ + pass + + +class NoReconnect(ReconnectionPolicy): + def should_retry(self, attempts: int) -> bool: + return False + + +class FixedReconnect(ReconnectionPolicy): + __slots__ = ("max_attempts", "delay") + + def __init__(self, attempts: int, delay: float): + self.max_attempts = attempts + self.delay = delay + + def should_retry(self, attempts: int) -> bool: + if attempts < self.max_attempts: + time.sleep(self.delay) + return True + + return False diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index fee27620..242f4abd 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -31,6 +31,7 @@ from ..tl.core import Serializable from ..tl.mtproto.functions import ping_delay_disconnect from ..tl.types import UpdateDeleteMessages, UpdateShort from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages +from .reconnection import ReconnectionPolicy MAXIMUM_DATA = (1024 * 1024) + (8 * 1024) @@ -164,6 +165,7 @@ class Sender: dc_id: int addr: str _connector: Connector + _reconnection_policy: Optional[ReconnectionPolicy] _logger: logging.Logger _reader: AsyncReader _writer: AsyncWriter @@ -188,6 +190,7 @@ class Sender: addr: str, *, connector: Connector, + reconnection_policy: Optional[ReconnectionPolicy], base_logger: logging.Logger, ) -> Self: ip, port = addr.split(":") @@ -197,6 +200,7 @@ class Sender: dc_id=dc_id, addr=addr, _connector=connector, + _reconnection_policy=reconnection_policy, _logger=base_logger.getChild("mtsender"), _reader=reader, _writer=writer, @@ -536,6 +540,7 @@ async def connect( auth_key: Optional[bytes], base_logger: logging.Logger, connector: Connector, + reconnection_policy: Optional[ReconnectionPolicy] = None, ) -> Sender: if auth_key is None: sender = await Sender.connect( @@ -544,6 +549,7 @@ async def connect( dc_id, addr, connector=connector, + reconnection_policy=reconnection_policy, base_logger=base_logger, ) return await generate_auth_key(sender) @@ -554,6 +560,7 @@ async def connect( dc_id, addr, connector=connector, + reconnection_policy=reconnection_policy, base_logger=base_logger, ) From 9c5a6af608011bb579f1a37a06139d97f9f033a7 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Mon, 2 Jun 2025 16:34:02 +0500 Subject: [PATCH 12/16] Refactor error messages for consistency and clarity; update reconnection logic in Sender class --- client/src/telethon/_impl/crypto/crypto.py | 6 +- .../telethon/_impl/mtproto/mtp/encrypted.py | 6 +- .../src/telethon/_impl/mtproto/mtp/types.py | 71 +++++++++++-------- .../telethon/_impl/mtproto/transport/abcs.py | 22 +++++- .../_impl/mtproto/transport/abridged.py | 2 +- .../telethon/_impl/mtproto/transport/full.py | 2 +- .../_impl/mtproto/transport/intermediate.py | 2 +- client/src/telethon/_impl/mtsender/errors.py | 11 +++ .../telethon/_impl/mtsender/reconnection.py | 8 +-- client/src/telethon/_impl/mtsender/sender.py | 49 +++++++++---- 10 files changed, 119 insertions(+), 60 deletions(-) create mode 100644 client/src/telethon/_impl/mtsender/errors.py diff --git a/client/src/telethon/_impl/crypto/crypto.py b/client/src/telethon/_impl/crypto/crypto.py index 1c47962e..a5a60a38 100644 --- a/client/src/telethon/_impl/crypto/crypto.py +++ b/client/src/telethon/_impl/crypto/crypto.py @@ -9,17 +9,17 @@ from .auth_key import AuthKey class InvalidBufferError(ValueError): def __init__(self) -> None: - super().__init__("Invalid ciphertext buffer length") + super().__init__("invalid ciphertext buffer length") class AuthKeyMismatchError(ValueError): def __init__(self) -> None: - super().__init__("Server authkey mismatches with ours") + super().__init__("server authkey mismatches with ours") class MsgKeyMismatchError(ValueError): def __init__(self) -> None: - super().__init__("Server msgkey mismatches with ours") + super().__init__("server msgkey mismatches with ours") CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index 78d2faa3..a9290fe8 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -249,7 +249,7 @@ class Encrypted(Mtp): except struct.error: # If the result is empty, we can't unpack it. # This can happen if the server returns an empty response. - logging.exception("Failed to unpack inner_constructor") + logging.exception("failed to unpack inner_constructor") self._deserialization.append( DeserializationFailure( msg_id=msg_id, @@ -266,7 +266,7 @@ class Encrypted(Mtp): error.msg_id = msg_id self._deserialization.append(error) except Exception: - logging.exception("Failed to deserialize error") + logging.exception("failed to deserialize error") self._deserialization.append( DeserializationFailure( msg_id=msg_id, @@ -287,7 +287,7 @@ class Encrypted(Mtp): self._store_own_updates(body) self._deserialization.append(RpcResult(msg_id, body)) except Exception: - logging.exception("Failed to decompress response") + logging.exception("failed to decompress response") self._deserialization.append( DeserializationFailure(msg_id=msg_id, error=DecompressionFailed()) ) diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index 5594b7bb..4d4b43bc 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -181,26 +181,10 @@ class BadMessageError(ValueError): return self._code == other._code -DeserializationError = ValueError - - -class DeserializationFailure: - __slots__ = ("msg_id", "error") - - def __init__(self, msg_id: MsgId, error: DeserializationError) -> None: - self.msg_id = msg_id - self.error = error - - -Deserialization = ( - Update | RpcResult | RpcError | BadMessageError | DeserializationFailure -) - - # Deserialization errors are not fatal, so we don't subclass RpcError. -class BadAuthKeyError(DeserializationError): +class BadAuthKeyError(ValueError): def __init__(self, *args: object, got: int, expected: int) -> None: - super().__init__(f"Bad server auth key (got {got}, expected {expected})", *args) + super().__init__(f"bad server auth key (got {got}, expected {expected})", *args) self._got = got self._expected = expected @@ -213,9 +197,9 @@ class BadAuthKeyError(DeserializationError): return self._expected -class BadMsgIdError(DeserializationError): +class BadMsgIdError(ValueError): def __init__(self, *args: object, got: int) -> None: - super().__init__(f"Bad server message id (got {got})", *args) + super().__init__(f"bad server message id (got {got})", *args) self._got = got @property @@ -223,9 +207,9 @@ class BadMsgIdError(DeserializationError): return self._got -class NegativeLengthError(DeserializationError): +class NegativeLengthError(ValueError): def __init__(self, *args: object, got: int) -> None: - super().__init__(f"Bad server message length (got {got})", *args) + super().__init__(f"bad server message length (got {got})", *args) self._got = got @property @@ -233,10 +217,10 @@ class NegativeLengthError(DeserializationError): return self._got -class TooLongMsgError(DeserializationError): +class TooLongMsgError(ValueError): def __init__(self, *args: object, got: int, max_length: int) -> None: super().__init__( - f"Bad server message length (got {got}, when at most it should be {max_length})", + f"bad server message length (got {got}, when at most it should be {max_length})", *args, ) self._got = got @@ -251,25 +235,25 @@ class TooLongMsgError(DeserializationError): return self._expected -class MsgBufferTooSmall(DeserializationError): +class MsgBufferTooSmall(ValueError): def __init__(self, *args: object) -> None: super().__init__( - "Server responded with a payload that's too small to fit a valid message", + "server responded with a payload that's too small to fit a valid message", *args, ) -class DecompressionFailed(DeserializationError): +class DecompressionFailed(ValueError): def __init__(self, *args: object) -> None: - super().__init__("Failed to decompress server's data", *args) + super().__init__("failed to decompress server's data", *args) -class UnexpectedConstructor(DeserializationError): +class UnexpectedConstructor(ValueError): def __init__(self, *args: object, id: int) -> None: - super().__init__(f"Unexpected constructor: {id:08x}", *args) + super().__init__(f"unexpected constructor: {id:08x}", *args) -class DecryptionError(DeserializationError): +class DecryptionError(ValueError): def __init__(self, *args: object, error: CryptoError) -> None: super().__init__(f"failed to decrypt message: {error}", *args) @@ -280,6 +264,31 @@ class DecryptionError(DeserializationError): return self._error +DeserializationError = ( + BadAuthKeyError + | BadMsgIdError + | NegativeLengthError + | TooLongMsgError + | MsgBufferTooSmall + | DecompressionFailed + | UnexpectedConstructor + | DecryptionError +) + + +class DeserializationFailure: + __slots__ = ("msg_id", "error") + + def __init__(self, msg_id: MsgId, error: DeserializationError) -> None: + self.msg_id = msg_id + self.error = error + + +Deserialization = ( + Update | RpcResult | RpcError | BadMessageError | DeserializationFailure +) + + # https://core.telegram.org/mtproto/description class Mtp(ABC): @abstractmethod diff --git a/client/src/telethon/_impl/mtproto/transport/abcs.py b/client/src/telethon/_impl/mtproto/transport/abcs.py index b3a16f58..5e319655 100644 --- a/client/src/telethon/_impl/mtproto/transport/abcs.py +++ b/client/src/telethon/_impl/mtproto/transport/abcs.py @@ -25,7 +25,27 @@ class MissingBytesError(ValueError): super().__init__(f"missing bytes, expected: {expected}, got: {got}") +class BadLenError(ValueError): + def __init__(self, *, got: int) -> None: + super().__init__(f"bad len (got {got})") + + +class BadSeqError(ValueError): + def __init__(self, *, expected: int, got: int) -> None: + super().__init__(f"bad seq (expected {expected}, got {got})") + + +class BadCrcError(ValueError): + def __init__(self, *, expected: int, got: int) -> None: + super().__init__(f"bad crc (expected {expected}, got {got})") + + class BadStatusError(ValueError): def __init__(self, *, status: int) -> None: - super().__init__(f"transport reported bad status: {status}") + super().__init__(f"bad status (negative length -{status})") self.status = status + + +TransportError = ( + MissingBytesError | BadLenError | BadSeqError | BadCrcError | BadStatusError +) diff --git a/client/src/telethon/_impl/mtproto/transport/abridged.py b/client/src/telethon/_impl/mtproto/transport/abridged.py index 65b92a14..2cd4fc9d 100644 --- a/client/src/telethon/_impl/mtproto/transport/abridged.py +++ b/client/src/telethon/_impl/mtproto/transport/abridged.py @@ -63,5 +63,5 @@ class Abridged(Transport): return header_len + length def reset(self): - logging.info("Resetting sending of header in abridged transport") + logging.info("resetting sending of header in abridged transport") self._init = False diff --git a/client/src/telethon/_impl/mtproto/transport/full.py b/client/src/telethon/_impl/mtproto/transport/full.py index f04d751c..c8a24ed1 100644 --- a/client/src/telethon/_impl/mtproto/transport/full.py +++ b/client/src/telethon/_impl/mtproto/transport/full.py @@ -64,6 +64,6 @@ class Full(Transport): return length def reset(self): - logging.info("Resetting recv and send seqs in full transport") + logging.info("resetting recv and send seqs in full transport") self._send_seq = 0 self._recv_seq = 0 diff --git a/client/src/telethon/_impl/mtproto/transport/intermediate.py b/client/src/telethon/_impl/mtproto/transport/intermediate.py index e75adff6..a533ab9a 100644 --- a/client/src/telethon/_impl/mtproto/transport/intermediate.py +++ b/client/src/telethon/_impl/mtproto/transport/intermediate.py @@ -55,5 +55,5 @@ class Intermediate(Transport): return length + 4 def reset(self): - logging.info("Resetting sending of header in intermediate transport") + logging.info("resetting sending of header in intermediate transport") self._init = False diff --git a/client/src/telethon/_impl/mtsender/errors.py b/client/src/telethon/_impl/mtsender/errors.py new file mode 100644 index 00000000..227d6fdd --- /dev/null +++ b/client/src/telethon/_impl/mtsender/errors.py @@ -0,0 +1,11 @@ +import io + +from ..mtproto.mtp.types import DeserializationError +from ..mtproto.transport.abcs import TransportError + +ReadError = io.BlockingIOError | TransportError | DeserializationError + + +class IOError(io.BlockingIOError): + def __init__(self, *args: object) -> None: + super().__init__(*args) diff --git a/client/src/telethon/_impl/mtsender/reconnection.py b/client/src/telethon/_impl/mtsender/reconnection.py index 2fb51ddd..363aad6d 100644 --- a/client/src/telethon/_impl/mtsender/reconnection.py +++ b/client/src/telethon/_impl/mtsender/reconnection.py @@ -1,4 +1,3 @@ -import time from abc import ABC, abstractmethod @@ -11,7 +10,7 @@ class ReconnectionPolicy(ABC): """ @abstractmethod - def should_retry(self, attempts: int) -> bool: + def should_retry(self, attempts: int) -> bool | float: """ Determines whether the client should retry the connection attempt. """ @@ -30,9 +29,8 @@ class FixedReconnect(ReconnectionPolicy): self.max_attempts = attempts self.delay = delay - def should_retry(self, attempts: int) -> bool: + def should_retry(self, attempts: int) -> bool | float: if attempts < self.max_attempts: - time.sleep(self.delay) - return True + return self.delay return False diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 242f4abd..9e63a4a8 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -291,19 +291,37 @@ class Sender: await self._writer.drain() self._on_net_write() - async def try_connect(self): - # attempts = 0 + async def _try_connect(self): + attempts = 0 ip, port = self.addr.split(":") while True: try: self._reader, self._writer = await self._connector(ip, int(port)) - break + self._logger.info( + f"auto-reconnect success after {attempts} failed attempt(s)" + ) + return except Exception as e: - logging.exception(e) - # TODO: reconnection_policy - break + attempts += 1 + self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}") + await asyncio.sleep(1) + + delay = False + + if self._reconnection_policy is not None: + delay = self._reconnection_policy.should_retry(attempts) + + if delay: + if delay is not True: + await asyncio.sleep(delay) + continue + elif delay is not None: + self._logger.info( + f"waiting {delay} seconds before next reconnection attempt" + ) + await asyncio.sleep(delay) def _try_fill_write(self) -> None: if not self._requests: @@ -362,11 +380,11 @@ class Sender: req.state = Sent(req.state.msg_id, req.state.container_msg_id) def _on_error(self, error: Exception): - logging.info(f"Handling error: {error}") + self._logger.info(f"handling error: {error}") self._transport.reset() self._mtp.reset() - logging.info( - "Resetting sender state from read_buffer {}, mtp_buffer {}".format( + self._logger.info( + "resetting sender state from read_buffer {}, mtp_buffer {}".format( len(self._read_buffer), len(self._mtp_buffer), ) @@ -374,9 +392,12 @@ class Sender: self._read_buffer.clear() self._mtp_buffer.clear() - # TODO: reset + match error: + # TODO + case DeserializationFailure(): + pass - logging.warning( + self._logger.warning( f"marking all {len(self._requests)} request(s) as failed: {error}" ) @@ -495,11 +516,11 @@ class Sender: req = self._pop_request(failure.msg_id) if req: - logging.debug(f"Got deserialization failure {failure.error}") + self._logger.debug(f"got deserialization failure {failure.error}") req.result.set_exception(failure.error) else: - logging.info( - f"Got deserialization failure {failure.error} but no such request is saved" + self._logger.info( + f"got deserialization failure {failure.error} but no such request is saved" ) def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]: From 4315ddbb57884a0a627a785a1d576dca4c3920ff Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Tue, 3 Jun 2025 09:41:06 +0500 Subject: [PATCH 13/16] Refactor error handling in deserialize method to use specific exception classes --- .../src/telethon/_impl/mtproto/mtp/plain.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/client/src/telethon/_impl/mtproto/mtp/plain.py b/client/src/telethon/_impl/mtproto/mtp/plain.py index 15e86caf..8dff2b19 100644 --- a/client/src/telethon/_impl/mtproto/mtp/plain.py +++ b/client/src/telethon/_impl/mtproto/mtp/plain.py @@ -2,7 +2,16 @@ import struct from typing import Optional from ..utils import check_message_buffer -from .types import Deserialization, MsgId, Mtp, RpcResult +from .types import ( + BadAuthKeyError, + BadMsgIdError, + Deserialization, + MsgId, + Mtp, + NegativeLengthError, + RpcResult, + TooLongMsgError, +) class Plain(Mtp): @@ -38,17 +47,17 @@ class Plain(Mtp): auth_key_id, msg_id, length = struct.unpack_from("= 0, got: {length}") + raise NegativeLengthError(got=length) if 20 + length > (lp := len(payload)): - raise ValueError(f"message too short, expected: {20 + length}, got {lp}") + raise TooLongMsgError(got=length, max_length=lp - 20) return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))] From e90d204287fac90a64a55ab7fc8d7f2648b2f905 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Tue, 3 Jun 2025 09:45:19 +0500 Subject: [PATCH 14/16] Improve error handling in reconnection logic by logging failures and raising exceptions --- client/src/telethon/_impl/mtsender/sender.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 9e63a4a8..8031f99b 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -317,11 +317,11 @@ class Sender: if delay is not True: await asyncio.sleep(delay) continue - elif delay is not None: - self._logger.info( - f"waiting {delay} seconds before next reconnection attempt" + else: + self._logger.error( + f"auto-reconnect failed {attempts} time(s); giving up" ) - await asyncio.sleep(delay) + raise def _try_fill_write(self) -> None: if not self._requests: From ac611dbbd42a3d072ff4a4671cd965a21101e432 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Tue, 3 Jun 2025 09:47:54 +0500 Subject: [PATCH 15/16] Refactor error message in Sender class for consistency and clarity --- client/src/telethon/_impl/mtsender/sender.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 8031f99b..f13f52b8 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -422,7 +422,7 @@ class Sender: self._process_deserialize_error(result) else: raise RuntimeError( - f"Unexpected result type {type(result).__name__!r}: {result}" + f"unexpected result type {type(result).__name__}: {result}" ) def _process_update(self, update: bytes | bytearray | memoryview) -> None: From 3bfa64a5d6f305e023fcc7820695c27073fde99a Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Tue, 3 Jun 2025 10:27:43 +0500 Subject: [PATCH 16/16] Add reconnection policy support to Sender and Config classes; refactor error handling in Sender --- .../telethon/_impl/client/client/client.py | 4 ++- .../src/telethon/_impl/client/client/net.py | 2 +- client/src/telethon/_impl/mtsender/errors.py | 9 ++---- client/src/telethon/_impl/mtsender/sender.py | 32 +++++++++++-------- 4 files changed, 24 insertions(+), 23 deletions(-) diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index 8fec6a95..91564209 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -10,6 +10,7 @@ from typing_extensions import Self from ....version import __version__ as default_version from ...mtsender import Connector, Sender +from ...mtsender.reconnection import NoReconnect, ReconnectionPolicy from ...session import ( ChannelRef, ChatHashCache, @@ -215,6 +216,7 @@ class Client: lang_code: Optional[str] = None, datacenter: Optional[DataCenter] = None, connector: Optional[Connector] = None, + reconnection_policy: Optional[ReconnectionPolicy] = None, ) -> None: assert __package__ base_logger = logger or logging.getLogger(__package__[: __package__.index(".")]) @@ -246,7 +248,7 @@ class Client: update_queue_limit=update_queue_limit, base_logger=base_logger, connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), - reconnection_policy=None, + reconnection_policy=reconnection_policy or NoReconnect(), ) self._session = Session() diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 2770ed9a..e790457d 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -46,6 +46,7 @@ class Config: api_hash: str base_logger: logging.Logger connector: Connector + reconnection_policy: ReconnectionPolicy device_model: str = field(default_factory=default_device_model) system_version: str = field(default_factory=default_system_version) app_version: str = __version__ @@ -55,7 +56,6 @@ class Config: datacenter: Optional[DataCenter] = None flood_sleep_threshold: int = 60 update_queue_limit: Optional[int] = None - reconnection_policy: Optional[ReconnectionPolicy] = None KNOWN_DCS = [ diff --git a/client/src/telethon/_impl/mtsender/errors.py b/client/src/telethon/_impl/mtsender/errors.py index 227d6fdd..26c94f06 100644 --- a/client/src/telethon/_impl/mtsender/errors.py +++ b/client/src/telethon/_impl/mtsender/errors.py @@ -1,11 +1,6 @@ -import io +from struct import error as struct_error from ..mtproto.mtp.types import DeserializationError from ..mtproto.transport.abcs import TransportError -ReadError = io.BlockingIOError | TransportError | DeserializationError - - -class IOError(io.BlockingIOError): - def __init__(self, *args: object) -> None: - super().__init__(*args) +ReadError = struct_error | TransportError | DeserializationError diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index f13f52b8..f9b8c1f5 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -29,7 +29,7 @@ from ..tl import Request as RemoteCall from ..tl.abcs import Updates from ..tl.core import Serializable from ..tl.mtproto.functions import ping_delay_disconnect -from ..tl.types import UpdateDeleteMessages, UpdateShort +from ..tl.types import UpdateDeleteMessages, UpdateShort, UpdatesTooLong from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages from .reconnection import ReconnectionPolicy @@ -165,7 +165,7 @@ class Sender: dc_id: int addr: str _connector: Connector - _reconnection_policy: Optional[ReconnectionPolicy] + _reconnection_policy: ReconnectionPolicy _logger: logging.Logger _reader: AsyncReader _writer: AsyncWriter @@ -190,7 +190,7 @@ class Sender: addr: str, *, connector: Connector, - reconnection_policy: Optional[ReconnectionPolicy], + reconnection_policy: ReconnectionPolicy, base_logger: logging.Logger, ) -> Self: ip, port = addr.split(":") @@ -249,7 +249,7 @@ class Sender: try: await self._step() except Exception as error: - self._on_error(error) + await self._on_error(error) async def _step(self) -> None: if not self._writing: @@ -308,10 +308,7 @@ class Sender: self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}") await asyncio.sleep(1) - delay = False - - if self._reconnection_policy is not None: - delay = self._reconnection_policy.should_retry(attempts) + delay = self._reconnection_policy.should_retry(attempts) if delay: if delay is not True: @@ -379,7 +376,7 @@ class Sender: if isinstance(req.state, Serialized): req.state = Sent(req.state.msg_id, req.state.container_msg_id) - def _on_error(self, error: Exception): + async def _on_error(self, error: Exception) -> None: self._logger.info(f"handling error: {error}") self._transport.reset() self._mtp.reset() @@ -392,10 +389,17 @@ class Sender: self._read_buffer.clear() self._mtp_buffer.clear() - match error: - # TODO - case DeserializationFailure(): - pass + if isinstance(error, struct.error) and self._reconnection_policy.should_retry( + 0 + ): + self._logger.info(f"read error occurred: {error}") + await self._try_connect() + + for req in self._requests: + req.state = NotSerialized() + + self._updates.append(UpdatesTooLong()) + return self._logger.warning( f"marking all {len(self._requests)} request(s) as failed: {error}" @@ -561,7 +565,7 @@ async def connect( auth_key: Optional[bytes], base_logger: logging.Logger, connector: Connector, - reconnection_policy: Optional[ReconnectionPolicy] = None, + reconnection_policy: ReconnectionPolicy, ) -> Sender: if auth_key is None: sender = await Sender.connect(