Refactor error messages for consistency and clarity; update reconnection logic in Sender class

This commit is contained in:
Jahongir Qurbonov 2025-06-02 16:34:02 +05:00
parent fdf2a05e3e
commit 9c5a6af608
No known key found for this signature in database
GPG Key ID: 256976CED13D5F2D
10 changed files with 119 additions and 60 deletions

View File

@ -9,17 +9,17 @@ from .auth_key import AuthKey
class InvalidBufferError(ValueError):
def __init__(self) -> None:
super().__init__("Invalid ciphertext buffer length")
super().__init__("invalid ciphertext buffer length")
class AuthKeyMismatchError(ValueError):
def __init__(self) -> None:
super().__init__("Server authkey mismatches with ours")
super().__init__("server authkey mismatches with ours")
class MsgKeyMismatchError(ValueError):
def __init__(self) -> None:
super().__init__("Server msgkey mismatches with ours")
super().__init__("server msgkey mismatches with ours")
CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError

View File

@ -249,7 +249,7 @@ class Encrypted(Mtp):
except struct.error:
# If the result is empty, we can't unpack it.
# This can happen if the server returns an empty response.
logging.exception("Failed to unpack inner_constructor")
logging.exception("failed to unpack inner_constructor")
self._deserialization.append(
DeserializationFailure(
msg_id=msg_id,
@ -266,7 +266,7 @@ class Encrypted(Mtp):
error.msg_id = msg_id
self._deserialization.append(error)
except Exception:
logging.exception("Failed to deserialize error")
logging.exception("failed to deserialize error")
self._deserialization.append(
DeserializationFailure(
msg_id=msg_id,
@ -287,7 +287,7 @@ class Encrypted(Mtp):
self._store_own_updates(body)
self._deserialization.append(RpcResult(msg_id, body))
except Exception:
logging.exception("Failed to decompress response")
logging.exception("failed to decompress response")
self._deserialization.append(
DeserializationFailure(msg_id=msg_id, error=DecompressionFailed())
)

View File

@ -181,26 +181,10 @@ class BadMessageError(ValueError):
return self._code == other._code
DeserializationError = ValueError
class DeserializationFailure:
__slots__ = ("msg_id", "error")
def __init__(self, msg_id: MsgId, error: DeserializationError) -> None:
self.msg_id = msg_id
self.error = error
Deserialization = (
Update | RpcResult | RpcError | BadMessageError | DeserializationFailure
)
# Deserialization errors are not fatal, so we don't subclass RpcError.
class BadAuthKeyError(DeserializationError):
class BadAuthKeyError(ValueError):
def __init__(self, *args: object, got: int, expected: int) -> None:
super().__init__(f"Bad server auth key (got {got}, expected {expected})", *args)
super().__init__(f"bad server auth key (got {got}, expected {expected})", *args)
self._got = got
self._expected = expected
@ -213,9 +197,9 @@ class BadAuthKeyError(DeserializationError):
return self._expected
class BadMsgIdError(DeserializationError):
class BadMsgIdError(ValueError):
def __init__(self, *args: object, got: int) -> None:
super().__init__(f"Bad server message id (got {got})", *args)
super().__init__(f"bad server message id (got {got})", *args)
self._got = got
@property
@ -223,9 +207,9 @@ class BadMsgIdError(DeserializationError):
return self._got
class NegativeLengthError(DeserializationError):
class NegativeLengthError(ValueError):
def __init__(self, *args: object, got: int) -> None:
super().__init__(f"Bad server message length (got {got})", *args)
super().__init__(f"bad server message length (got {got})", *args)
self._got = got
@property
@ -233,10 +217,10 @@ class NegativeLengthError(DeserializationError):
return self._got
class TooLongMsgError(DeserializationError):
class TooLongMsgError(ValueError):
def __init__(self, *args: object, got: int, max_length: int) -> None:
super().__init__(
f"Bad server message length (got {got}, when at most it should be {max_length})",
f"bad server message length (got {got}, when at most it should be {max_length})",
*args,
)
self._got = got
@ -251,25 +235,25 @@ class TooLongMsgError(DeserializationError):
return self._expected
class MsgBufferTooSmall(DeserializationError):
class MsgBufferTooSmall(ValueError):
def __init__(self, *args: object) -> None:
super().__init__(
"Server responded with a payload that's too small to fit a valid message",
"server responded with a payload that's too small to fit a valid message",
*args,
)
class DecompressionFailed(DeserializationError):
class DecompressionFailed(ValueError):
def __init__(self, *args: object) -> None:
super().__init__("Failed to decompress server's data", *args)
super().__init__("failed to decompress server's data", *args)
class UnexpectedConstructor(DeserializationError):
class UnexpectedConstructor(ValueError):
def __init__(self, *args: object, id: int) -> None:
super().__init__(f"Unexpected constructor: {id:08x}", *args)
super().__init__(f"unexpected constructor: {id:08x}", *args)
class DecryptionError(DeserializationError):
class DecryptionError(ValueError):
def __init__(self, *args: object, error: CryptoError) -> None:
super().__init__(f"failed to decrypt message: {error}", *args)
@ -280,6 +264,31 @@ class DecryptionError(DeserializationError):
return self._error
DeserializationError = (
BadAuthKeyError
| BadMsgIdError
| NegativeLengthError
| TooLongMsgError
| MsgBufferTooSmall
| DecompressionFailed
| UnexpectedConstructor
| DecryptionError
)
class DeserializationFailure:
__slots__ = ("msg_id", "error")
def __init__(self, msg_id: MsgId, error: DeserializationError) -> None:
self.msg_id = msg_id
self.error = error
Deserialization = (
Update | RpcResult | RpcError | BadMessageError | DeserializationFailure
)
# https://core.telegram.org/mtproto/description
class Mtp(ABC):
@abstractmethod

View File

@ -25,7 +25,27 @@ class MissingBytesError(ValueError):
super().__init__(f"missing bytes, expected: {expected}, got: {got}")
class BadLenError(ValueError):
def __init__(self, *, got: int) -> None:
super().__init__(f"bad len (got {got})")
class BadSeqError(ValueError):
def __init__(self, *, expected: int, got: int) -> None:
super().__init__(f"bad seq (expected {expected}, got {got})")
class BadCrcError(ValueError):
def __init__(self, *, expected: int, got: int) -> None:
super().__init__(f"bad crc (expected {expected}, got {got})")
class BadStatusError(ValueError):
def __init__(self, *, status: int) -> None:
super().__init__(f"transport reported bad status: {status}")
super().__init__(f"bad status (negative length -{status})")
self.status = status
TransportError = (
MissingBytesError | BadLenError | BadSeqError | BadCrcError | BadStatusError
)

View File

@ -63,5 +63,5 @@ class Abridged(Transport):
return header_len + length
def reset(self):
logging.info("Resetting sending of header in abridged transport")
logging.info("resetting sending of header in abridged transport")
self._init = False

View File

@ -64,6 +64,6 @@ class Full(Transport):
return length
def reset(self):
logging.info("Resetting recv and send seqs in full transport")
logging.info("resetting recv and send seqs in full transport")
self._send_seq = 0
self._recv_seq = 0

View File

@ -55,5 +55,5 @@ class Intermediate(Transport):
return length + 4
def reset(self):
logging.info("Resetting sending of header in intermediate transport")
logging.info("resetting sending of header in intermediate transport")
self._init = False

View File

@ -0,0 +1,11 @@
import io
from ..mtproto.mtp.types import DeserializationError
from ..mtproto.transport.abcs import TransportError
ReadError = io.BlockingIOError | TransportError | DeserializationError
class IOError(io.BlockingIOError):
def __init__(self, *args: object) -> None:
super().__init__(*args)

View File

@ -1,4 +1,3 @@
import time
from abc import ABC, abstractmethod
@ -11,7 +10,7 @@ class ReconnectionPolicy(ABC):
"""
@abstractmethod
def should_retry(self, attempts: int) -> bool:
def should_retry(self, attempts: int) -> bool | float:
"""
Determines whether the client should retry the connection attempt.
"""
@ -30,9 +29,8 @@ class FixedReconnect(ReconnectionPolicy):
self.max_attempts = attempts
self.delay = delay
def should_retry(self, attempts: int) -> bool:
def should_retry(self, attempts: int) -> bool | float:
if attempts < self.max_attempts:
time.sleep(self.delay)
return True
return self.delay
return False

View File

@ -291,19 +291,37 @@ class Sender:
await self._writer.drain()
self._on_net_write()
async def try_connect(self):
# attempts = 0
async def _try_connect(self):
attempts = 0
ip, port = self.addr.split(":")
while True:
try:
self._reader, self._writer = await self._connector(ip, int(port))
break
self._logger.info(
f"auto-reconnect success after {attempts} failed attempt(s)"
)
return
except Exception as e:
logging.exception(e)
# TODO: reconnection_policy
break
attempts += 1
self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}")
await asyncio.sleep(1)
delay = False
if self._reconnection_policy is not None:
delay = self._reconnection_policy.should_retry(attempts)
if delay:
if delay is not True:
await asyncio.sleep(delay)
continue
elif delay is not None:
self._logger.info(
f"waiting {delay} seconds before next reconnection attempt"
)
await asyncio.sleep(delay)
def _try_fill_write(self) -> None:
if not self._requests:
@ -362,11 +380,11 @@ class Sender:
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
def _on_error(self, error: Exception):
logging.info(f"Handling error: {error}")
self._logger.info(f"handling error: {error}")
self._transport.reset()
self._mtp.reset()
logging.info(
"Resetting sender state from read_buffer {}, mtp_buffer {}".format(
self._logger.info(
"resetting sender state from read_buffer {}, mtp_buffer {}".format(
len(self._read_buffer),
len(self._mtp_buffer),
)
@ -374,9 +392,12 @@ class Sender:
self._read_buffer.clear()
self._mtp_buffer.clear()
# TODO: reset
match error:
# TODO
case DeserializationFailure():
pass
logging.warning(
self._logger.warning(
f"marking all {len(self._requests)} request(s) as failed: {error}"
)
@ -495,11 +516,11 @@ class Sender:
req = self._pop_request(failure.msg_id)
if req:
logging.debug(f"Got deserialization failure {failure.error}")
self._logger.debug(f"got deserialization failure {failure.error}")
req.result.set_exception(failure.error)
else:
logging.info(
f"Got deserialization failure {failure.error} but no such request is saved"
self._logger.info(
f"got deserialization failure {failure.error} but no such request is saved"
)
def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]: