mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-06-18 02:56:39 +00:00
Refactor error messages for consistency and clarity; update reconnection logic in Sender class
This commit is contained in:
parent
fdf2a05e3e
commit
9c5a6af608
@ -9,17 +9,17 @@ from .auth_key import AuthKey
|
||||
|
||||
class InvalidBufferError(ValueError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Invalid ciphertext buffer length")
|
||||
super().__init__("invalid ciphertext buffer length")
|
||||
|
||||
|
||||
class AuthKeyMismatchError(ValueError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Server authkey mismatches with ours")
|
||||
super().__init__("server authkey mismatches with ours")
|
||||
|
||||
|
||||
class MsgKeyMismatchError(ValueError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Server msgkey mismatches with ours")
|
||||
super().__init__("server msgkey mismatches with ours")
|
||||
|
||||
|
||||
CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError
|
||||
|
@ -249,7 +249,7 @@ class Encrypted(Mtp):
|
||||
except struct.error:
|
||||
# If the result is empty, we can't unpack it.
|
||||
# This can happen if the server returns an empty response.
|
||||
logging.exception("Failed to unpack inner_constructor")
|
||||
logging.exception("failed to unpack inner_constructor")
|
||||
self._deserialization.append(
|
||||
DeserializationFailure(
|
||||
msg_id=msg_id,
|
||||
@ -266,7 +266,7 @@ class Encrypted(Mtp):
|
||||
error.msg_id = msg_id
|
||||
self._deserialization.append(error)
|
||||
except Exception:
|
||||
logging.exception("Failed to deserialize error")
|
||||
logging.exception("failed to deserialize error")
|
||||
self._deserialization.append(
|
||||
DeserializationFailure(
|
||||
msg_id=msg_id,
|
||||
@ -287,7 +287,7 @@ class Encrypted(Mtp):
|
||||
self._store_own_updates(body)
|
||||
self._deserialization.append(RpcResult(msg_id, body))
|
||||
except Exception:
|
||||
logging.exception("Failed to decompress response")
|
||||
logging.exception("failed to decompress response")
|
||||
self._deserialization.append(
|
||||
DeserializationFailure(msg_id=msg_id, error=DecompressionFailed())
|
||||
)
|
||||
|
@ -181,26 +181,10 @@ class BadMessageError(ValueError):
|
||||
return self._code == other._code
|
||||
|
||||
|
||||
DeserializationError = ValueError
|
||||
|
||||
|
||||
class DeserializationFailure:
|
||||
__slots__ = ("msg_id", "error")
|
||||
|
||||
def __init__(self, msg_id: MsgId, error: DeserializationError) -> None:
|
||||
self.msg_id = msg_id
|
||||
self.error = error
|
||||
|
||||
|
||||
Deserialization = (
|
||||
Update | RpcResult | RpcError | BadMessageError | DeserializationFailure
|
||||
)
|
||||
|
||||
|
||||
# Deserialization errors are not fatal, so we don't subclass RpcError.
|
||||
class BadAuthKeyError(DeserializationError):
|
||||
class BadAuthKeyError(ValueError):
|
||||
def __init__(self, *args: object, got: int, expected: int) -> None:
|
||||
super().__init__(f"Bad server auth key (got {got}, expected {expected})", *args)
|
||||
super().__init__(f"bad server auth key (got {got}, expected {expected})", *args)
|
||||
self._got = got
|
||||
self._expected = expected
|
||||
|
||||
@ -213,9 +197,9 @@ class BadAuthKeyError(DeserializationError):
|
||||
return self._expected
|
||||
|
||||
|
||||
class BadMsgIdError(DeserializationError):
|
||||
class BadMsgIdError(ValueError):
|
||||
def __init__(self, *args: object, got: int) -> None:
|
||||
super().__init__(f"Bad server message id (got {got})", *args)
|
||||
super().__init__(f"bad server message id (got {got})", *args)
|
||||
self._got = got
|
||||
|
||||
@property
|
||||
@ -223,9 +207,9 @@ class BadMsgIdError(DeserializationError):
|
||||
return self._got
|
||||
|
||||
|
||||
class NegativeLengthError(DeserializationError):
|
||||
class NegativeLengthError(ValueError):
|
||||
def __init__(self, *args: object, got: int) -> None:
|
||||
super().__init__(f"Bad server message length (got {got})", *args)
|
||||
super().__init__(f"bad server message length (got {got})", *args)
|
||||
self._got = got
|
||||
|
||||
@property
|
||||
@ -233,10 +217,10 @@ class NegativeLengthError(DeserializationError):
|
||||
return self._got
|
||||
|
||||
|
||||
class TooLongMsgError(DeserializationError):
|
||||
class TooLongMsgError(ValueError):
|
||||
def __init__(self, *args: object, got: int, max_length: int) -> None:
|
||||
super().__init__(
|
||||
f"Bad server message length (got {got}, when at most it should be {max_length})",
|
||||
f"bad server message length (got {got}, when at most it should be {max_length})",
|
||||
*args,
|
||||
)
|
||||
self._got = got
|
||||
@ -251,25 +235,25 @@ class TooLongMsgError(DeserializationError):
|
||||
return self._expected
|
||||
|
||||
|
||||
class MsgBufferTooSmall(DeserializationError):
|
||||
class MsgBufferTooSmall(ValueError):
|
||||
def __init__(self, *args: object) -> None:
|
||||
super().__init__(
|
||||
"Server responded with a payload that's too small to fit a valid message",
|
||||
"server responded with a payload that's too small to fit a valid message",
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
class DecompressionFailed(DeserializationError):
|
||||
class DecompressionFailed(ValueError):
|
||||
def __init__(self, *args: object) -> None:
|
||||
super().__init__("Failed to decompress server's data", *args)
|
||||
super().__init__("failed to decompress server's data", *args)
|
||||
|
||||
|
||||
class UnexpectedConstructor(DeserializationError):
|
||||
class UnexpectedConstructor(ValueError):
|
||||
def __init__(self, *args: object, id: int) -> None:
|
||||
super().__init__(f"Unexpected constructor: {id:08x}", *args)
|
||||
super().__init__(f"unexpected constructor: {id:08x}", *args)
|
||||
|
||||
|
||||
class DecryptionError(DeserializationError):
|
||||
class DecryptionError(ValueError):
|
||||
def __init__(self, *args: object, error: CryptoError) -> None:
|
||||
super().__init__(f"failed to decrypt message: {error}", *args)
|
||||
|
||||
@ -280,6 +264,31 @@ class DecryptionError(DeserializationError):
|
||||
return self._error
|
||||
|
||||
|
||||
DeserializationError = (
|
||||
BadAuthKeyError
|
||||
| BadMsgIdError
|
||||
| NegativeLengthError
|
||||
| TooLongMsgError
|
||||
| MsgBufferTooSmall
|
||||
| DecompressionFailed
|
||||
| UnexpectedConstructor
|
||||
| DecryptionError
|
||||
)
|
||||
|
||||
|
||||
class DeserializationFailure:
|
||||
__slots__ = ("msg_id", "error")
|
||||
|
||||
def __init__(self, msg_id: MsgId, error: DeserializationError) -> None:
|
||||
self.msg_id = msg_id
|
||||
self.error = error
|
||||
|
||||
|
||||
Deserialization = (
|
||||
Update | RpcResult | RpcError | BadMessageError | DeserializationFailure
|
||||
)
|
||||
|
||||
|
||||
# https://core.telegram.org/mtproto/description
|
||||
class Mtp(ABC):
|
||||
@abstractmethod
|
||||
|
@ -25,7 +25,27 @@ class MissingBytesError(ValueError):
|
||||
super().__init__(f"missing bytes, expected: {expected}, got: {got}")
|
||||
|
||||
|
||||
class BadLenError(ValueError):
|
||||
def __init__(self, *, got: int) -> None:
|
||||
super().__init__(f"bad len (got {got})")
|
||||
|
||||
|
||||
class BadSeqError(ValueError):
|
||||
def __init__(self, *, expected: int, got: int) -> None:
|
||||
super().__init__(f"bad seq (expected {expected}, got {got})")
|
||||
|
||||
|
||||
class BadCrcError(ValueError):
|
||||
def __init__(self, *, expected: int, got: int) -> None:
|
||||
super().__init__(f"bad crc (expected {expected}, got {got})")
|
||||
|
||||
|
||||
class BadStatusError(ValueError):
|
||||
def __init__(self, *, status: int) -> None:
|
||||
super().__init__(f"transport reported bad status: {status}")
|
||||
super().__init__(f"bad status (negative length -{status})")
|
||||
self.status = status
|
||||
|
||||
|
||||
TransportError = (
|
||||
MissingBytesError | BadLenError | BadSeqError | BadCrcError | BadStatusError
|
||||
)
|
||||
|
@ -63,5 +63,5 @@ class Abridged(Transport):
|
||||
return header_len + length
|
||||
|
||||
def reset(self):
|
||||
logging.info("Resetting sending of header in abridged transport")
|
||||
logging.info("resetting sending of header in abridged transport")
|
||||
self._init = False
|
||||
|
@ -64,6 +64,6 @@ class Full(Transport):
|
||||
return length
|
||||
|
||||
def reset(self):
|
||||
logging.info("Resetting recv and send seqs in full transport")
|
||||
logging.info("resetting recv and send seqs in full transport")
|
||||
self._send_seq = 0
|
||||
self._recv_seq = 0
|
||||
|
@ -55,5 +55,5 @@ class Intermediate(Transport):
|
||||
return length + 4
|
||||
|
||||
def reset(self):
|
||||
logging.info("Resetting sending of header in intermediate transport")
|
||||
logging.info("resetting sending of header in intermediate transport")
|
||||
self._init = False
|
||||
|
11
client/src/telethon/_impl/mtsender/errors.py
Normal file
11
client/src/telethon/_impl/mtsender/errors.py
Normal file
@ -0,0 +1,11 @@
|
||||
import io
|
||||
|
||||
from ..mtproto.mtp.types import DeserializationError
|
||||
from ..mtproto.transport.abcs import TransportError
|
||||
|
||||
ReadError = io.BlockingIOError | TransportError | DeserializationError
|
||||
|
||||
|
||||
class IOError(io.BlockingIOError):
|
||||
def __init__(self, *args: object) -> None:
|
||||
super().__init__(*args)
|
@ -1,4 +1,3 @@
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
@ -11,7 +10,7 @@ class ReconnectionPolicy(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def should_retry(self, attempts: int) -> bool:
|
||||
def should_retry(self, attempts: int) -> bool | float:
|
||||
"""
|
||||
Determines whether the client should retry the connection attempt.
|
||||
"""
|
||||
@ -30,9 +29,8 @@ class FixedReconnect(ReconnectionPolicy):
|
||||
self.max_attempts = attempts
|
||||
self.delay = delay
|
||||
|
||||
def should_retry(self, attempts: int) -> bool:
|
||||
def should_retry(self, attempts: int) -> bool | float:
|
||||
if attempts < self.max_attempts:
|
||||
time.sleep(self.delay)
|
||||
return True
|
||||
return self.delay
|
||||
|
||||
return False
|
||||
|
@ -291,19 +291,37 @@ class Sender:
|
||||
await self._writer.drain()
|
||||
self._on_net_write()
|
||||
|
||||
async def try_connect(self):
|
||||
# attempts = 0
|
||||
async def _try_connect(self):
|
||||
attempts = 0
|
||||
|
||||
ip, port = self.addr.split(":")
|
||||
|
||||
while True:
|
||||
try:
|
||||
self._reader, self._writer = await self._connector(ip, int(port))
|
||||
break
|
||||
self._logger.info(
|
||||
f"auto-reconnect success after {attempts} failed attempt(s)"
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
# TODO: reconnection_policy
|
||||
break
|
||||
attempts += 1
|
||||
self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
delay = False
|
||||
|
||||
if self._reconnection_policy is not None:
|
||||
delay = self._reconnection_policy.should_retry(attempts)
|
||||
|
||||
if delay:
|
||||
if delay is not True:
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
elif delay is not None:
|
||||
self._logger.info(
|
||||
f"waiting {delay} seconds before next reconnection attempt"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
def _try_fill_write(self) -> None:
|
||||
if not self._requests:
|
||||
@ -362,11 +380,11 @@ class Sender:
|
||||
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
|
||||
|
||||
def _on_error(self, error: Exception):
|
||||
logging.info(f"Handling error: {error}")
|
||||
self._logger.info(f"handling error: {error}")
|
||||
self._transport.reset()
|
||||
self._mtp.reset()
|
||||
logging.info(
|
||||
"Resetting sender state from read_buffer {}, mtp_buffer {}".format(
|
||||
self._logger.info(
|
||||
"resetting sender state from read_buffer {}, mtp_buffer {}".format(
|
||||
len(self._read_buffer),
|
||||
len(self._mtp_buffer),
|
||||
)
|
||||
@ -374,9 +392,12 @@ class Sender:
|
||||
self._read_buffer.clear()
|
||||
self._mtp_buffer.clear()
|
||||
|
||||
# TODO: reset
|
||||
match error:
|
||||
# TODO
|
||||
case DeserializationFailure():
|
||||
pass
|
||||
|
||||
logging.warning(
|
||||
self._logger.warning(
|
||||
f"marking all {len(self._requests)} request(s) as failed: {error}"
|
||||
)
|
||||
|
||||
@ -495,11 +516,11 @@ class Sender:
|
||||
req = self._pop_request(failure.msg_id)
|
||||
|
||||
if req:
|
||||
logging.debug(f"Got deserialization failure {failure.error}")
|
||||
self._logger.debug(f"got deserialization failure {failure.error}")
|
||||
req.result.set_exception(failure.error)
|
||||
else:
|
||||
logging.info(
|
||||
f"Got deserialization failure {failure.error} but no such request is saved"
|
||||
self._logger.info(
|
||||
f"got deserialization failure {failure.error} but no such request is saved"
|
||||
)
|
||||
|
||||
def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]:
|
||||
|
Loading…
Reference in New Issue
Block a user