mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-06-18 11:06:39 +00:00
Refactor error messages for consistency and clarity; update reconnection logic in Sender class
This commit is contained in:
parent
fdf2a05e3e
commit
9c5a6af608
@ -9,17 +9,17 @@ from .auth_key import AuthKey
|
|||||||
|
|
||||||
class InvalidBufferError(ValueError):
|
class InvalidBufferError(ValueError):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__("Invalid ciphertext buffer length")
|
super().__init__("invalid ciphertext buffer length")
|
||||||
|
|
||||||
|
|
||||||
class AuthKeyMismatchError(ValueError):
|
class AuthKeyMismatchError(ValueError):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__("Server authkey mismatches with ours")
|
super().__init__("server authkey mismatches with ours")
|
||||||
|
|
||||||
|
|
||||||
class MsgKeyMismatchError(ValueError):
|
class MsgKeyMismatchError(ValueError):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__("Server msgkey mismatches with ours")
|
super().__init__("server msgkey mismatches with ours")
|
||||||
|
|
||||||
|
|
||||||
CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError
|
CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError
|
||||||
|
@ -249,7 +249,7 @@ class Encrypted(Mtp):
|
|||||||
except struct.error:
|
except struct.error:
|
||||||
# If the result is empty, we can't unpack it.
|
# If the result is empty, we can't unpack it.
|
||||||
# This can happen if the server returns an empty response.
|
# This can happen if the server returns an empty response.
|
||||||
logging.exception("Failed to unpack inner_constructor")
|
logging.exception("failed to unpack inner_constructor")
|
||||||
self._deserialization.append(
|
self._deserialization.append(
|
||||||
DeserializationFailure(
|
DeserializationFailure(
|
||||||
msg_id=msg_id,
|
msg_id=msg_id,
|
||||||
@ -266,7 +266,7 @@ class Encrypted(Mtp):
|
|||||||
error.msg_id = msg_id
|
error.msg_id = msg_id
|
||||||
self._deserialization.append(error)
|
self._deserialization.append(error)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception("Failed to deserialize error")
|
logging.exception("failed to deserialize error")
|
||||||
self._deserialization.append(
|
self._deserialization.append(
|
||||||
DeserializationFailure(
|
DeserializationFailure(
|
||||||
msg_id=msg_id,
|
msg_id=msg_id,
|
||||||
@ -287,7 +287,7 @@ class Encrypted(Mtp):
|
|||||||
self._store_own_updates(body)
|
self._store_own_updates(body)
|
||||||
self._deserialization.append(RpcResult(msg_id, body))
|
self._deserialization.append(RpcResult(msg_id, body))
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception("Failed to decompress response")
|
logging.exception("failed to decompress response")
|
||||||
self._deserialization.append(
|
self._deserialization.append(
|
||||||
DeserializationFailure(msg_id=msg_id, error=DecompressionFailed())
|
DeserializationFailure(msg_id=msg_id, error=DecompressionFailed())
|
||||||
)
|
)
|
||||||
|
@ -181,26 +181,10 @@ class BadMessageError(ValueError):
|
|||||||
return self._code == other._code
|
return self._code == other._code
|
||||||
|
|
||||||
|
|
||||||
DeserializationError = ValueError
|
|
||||||
|
|
||||||
|
|
||||||
class DeserializationFailure:
|
|
||||||
__slots__ = ("msg_id", "error")
|
|
||||||
|
|
||||||
def __init__(self, msg_id: MsgId, error: DeserializationError) -> None:
|
|
||||||
self.msg_id = msg_id
|
|
||||||
self.error = error
|
|
||||||
|
|
||||||
|
|
||||||
Deserialization = (
|
|
||||||
Update | RpcResult | RpcError | BadMessageError | DeserializationFailure
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Deserialization errors are not fatal, so we don't subclass RpcError.
|
# Deserialization errors are not fatal, so we don't subclass RpcError.
|
||||||
class BadAuthKeyError(DeserializationError):
|
class BadAuthKeyError(ValueError):
|
||||||
def __init__(self, *args: object, got: int, expected: int) -> None:
|
def __init__(self, *args: object, got: int, expected: int) -> None:
|
||||||
super().__init__(f"Bad server auth key (got {got}, expected {expected})", *args)
|
super().__init__(f"bad server auth key (got {got}, expected {expected})", *args)
|
||||||
self._got = got
|
self._got = got
|
||||||
self._expected = expected
|
self._expected = expected
|
||||||
|
|
||||||
@ -213,9 +197,9 @@ class BadAuthKeyError(DeserializationError):
|
|||||||
return self._expected
|
return self._expected
|
||||||
|
|
||||||
|
|
||||||
class BadMsgIdError(DeserializationError):
|
class BadMsgIdError(ValueError):
|
||||||
def __init__(self, *args: object, got: int) -> None:
|
def __init__(self, *args: object, got: int) -> None:
|
||||||
super().__init__(f"Bad server message id (got {got})", *args)
|
super().__init__(f"bad server message id (got {got})", *args)
|
||||||
self._got = got
|
self._got = got
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -223,9 +207,9 @@ class BadMsgIdError(DeserializationError):
|
|||||||
return self._got
|
return self._got
|
||||||
|
|
||||||
|
|
||||||
class NegativeLengthError(DeserializationError):
|
class NegativeLengthError(ValueError):
|
||||||
def __init__(self, *args: object, got: int) -> None:
|
def __init__(self, *args: object, got: int) -> None:
|
||||||
super().__init__(f"Bad server message length (got {got})", *args)
|
super().__init__(f"bad server message length (got {got})", *args)
|
||||||
self._got = got
|
self._got = got
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -233,10 +217,10 @@ class NegativeLengthError(DeserializationError):
|
|||||||
return self._got
|
return self._got
|
||||||
|
|
||||||
|
|
||||||
class TooLongMsgError(DeserializationError):
|
class TooLongMsgError(ValueError):
|
||||||
def __init__(self, *args: object, got: int, max_length: int) -> None:
|
def __init__(self, *args: object, got: int, max_length: int) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
f"Bad server message length (got {got}, when at most it should be {max_length})",
|
f"bad server message length (got {got}, when at most it should be {max_length})",
|
||||||
*args,
|
*args,
|
||||||
)
|
)
|
||||||
self._got = got
|
self._got = got
|
||||||
@ -251,25 +235,25 @@ class TooLongMsgError(DeserializationError):
|
|||||||
return self._expected
|
return self._expected
|
||||||
|
|
||||||
|
|
||||||
class MsgBufferTooSmall(DeserializationError):
|
class MsgBufferTooSmall(ValueError):
|
||||||
def __init__(self, *args: object) -> None:
|
def __init__(self, *args: object) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
"Server responded with a payload that's too small to fit a valid message",
|
"server responded with a payload that's too small to fit a valid message",
|
||||||
*args,
|
*args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DecompressionFailed(DeserializationError):
|
class DecompressionFailed(ValueError):
|
||||||
def __init__(self, *args: object) -> None:
|
def __init__(self, *args: object) -> None:
|
||||||
super().__init__("Failed to decompress server's data", *args)
|
super().__init__("failed to decompress server's data", *args)
|
||||||
|
|
||||||
|
|
||||||
class UnexpectedConstructor(DeserializationError):
|
class UnexpectedConstructor(ValueError):
|
||||||
def __init__(self, *args: object, id: int) -> None:
|
def __init__(self, *args: object, id: int) -> None:
|
||||||
super().__init__(f"Unexpected constructor: {id:08x}", *args)
|
super().__init__(f"unexpected constructor: {id:08x}", *args)
|
||||||
|
|
||||||
|
|
||||||
class DecryptionError(DeserializationError):
|
class DecryptionError(ValueError):
|
||||||
def __init__(self, *args: object, error: CryptoError) -> None:
|
def __init__(self, *args: object, error: CryptoError) -> None:
|
||||||
super().__init__(f"failed to decrypt message: {error}", *args)
|
super().__init__(f"failed to decrypt message: {error}", *args)
|
||||||
|
|
||||||
@ -280,6 +264,31 @@ class DecryptionError(DeserializationError):
|
|||||||
return self._error
|
return self._error
|
||||||
|
|
||||||
|
|
||||||
|
DeserializationError = (
|
||||||
|
BadAuthKeyError
|
||||||
|
| BadMsgIdError
|
||||||
|
| NegativeLengthError
|
||||||
|
| TooLongMsgError
|
||||||
|
| MsgBufferTooSmall
|
||||||
|
| DecompressionFailed
|
||||||
|
| UnexpectedConstructor
|
||||||
|
| DecryptionError
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializationFailure:
|
||||||
|
__slots__ = ("msg_id", "error")
|
||||||
|
|
||||||
|
def __init__(self, msg_id: MsgId, error: DeserializationError) -> None:
|
||||||
|
self.msg_id = msg_id
|
||||||
|
self.error = error
|
||||||
|
|
||||||
|
|
||||||
|
Deserialization = (
|
||||||
|
Update | RpcResult | RpcError | BadMessageError | DeserializationFailure
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/description
|
# https://core.telegram.org/mtproto/description
|
||||||
class Mtp(ABC):
|
class Mtp(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -25,7 +25,27 @@ class MissingBytesError(ValueError):
|
|||||||
super().__init__(f"missing bytes, expected: {expected}, got: {got}")
|
super().__init__(f"missing bytes, expected: {expected}, got: {got}")
|
||||||
|
|
||||||
|
|
||||||
|
class BadLenError(ValueError):
|
||||||
|
def __init__(self, *, got: int) -> None:
|
||||||
|
super().__init__(f"bad len (got {got})")
|
||||||
|
|
||||||
|
|
||||||
|
class BadSeqError(ValueError):
|
||||||
|
def __init__(self, *, expected: int, got: int) -> None:
|
||||||
|
super().__init__(f"bad seq (expected {expected}, got {got})")
|
||||||
|
|
||||||
|
|
||||||
|
class BadCrcError(ValueError):
|
||||||
|
def __init__(self, *, expected: int, got: int) -> None:
|
||||||
|
super().__init__(f"bad crc (expected {expected}, got {got})")
|
||||||
|
|
||||||
|
|
||||||
class BadStatusError(ValueError):
|
class BadStatusError(ValueError):
|
||||||
def __init__(self, *, status: int) -> None:
|
def __init__(self, *, status: int) -> None:
|
||||||
super().__init__(f"transport reported bad status: {status}")
|
super().__init__(f"bad status (negative length -{status})")
|
||||||
self.status = status
|
self.status = status
|
||||||
|
|
||||||
|
|
||||||
|
TransportError = (
|
||||||
|
MissingBytesError | BadLenError | BadSeqError | BadCrcError | BadStatusError
|
||||||
|
)
|
||||||
|
@ -63,5 +63,5 @@ class Abridged(Transport):
|
|||||||
return header_len + length
|
return header_len + length
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
logging.info("Resetting sending of header in abridged transport")
|
logging.info("resetting sending of header in abridged transport")
|
||||||
self._init = False
|
self._init = False
|
||||||
|
@ -64,6 +64,6 @@ class Full(Transport):
|
|||||||
return length
|
return length
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
logging.info("Resetting recv and send seqs in full transport")
|
logging.info("resetting recv and send seqs in full transport")
|
||||||
self._send_seq = 0
|
self._send_seq = 0
|
||||||
self._recv_seq = 0
|
self._recv_seq = 0
|
||||||
|
@ -55,5 +55,5 @@ class Intermediate(Transport):
|
|||||||
return length + 4
|
return length + 4
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
logging.info("Resetting sending of header in intermediate transport")
|
logging.info("resetting sending of header in intermediate transport")
|
||||||
self._init = False
|
self._init = False
|
||||||
|
11
client/src/telethon/_impl/mtsender/errors.py
Normal file
11
client/src/telethon/_impl/mtsender/errors.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import io
|
||||||
|
|
||||||
|
from ..mtproto.mtp.types import DeserializationError
|
||||||
|
from ..mtproto.transport.abcs import TransportError
|
||||||
|
|
||||||
|
ReadError = io.BlockingIOError | TransportError | DeserializationError
|
||||||
|
|
||||||
|
|
||||||
|
class IOError(io.BlockingIOError):
|
||||||
|
def __init__(self, *args: object) -> None:
|
||||||
|
super().__init__(*args)
|
@ -1,4 +1,3 @@
|
|||||||
import time
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
@ -11,7 +10,7 @@ class ReconnectionPolicy(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def should_retry(self, attempts: int) -> bool:
|
def should_retry(self, attempts: int) -> bool | float:
|
||||||
"""
|
"""
|
||||||
Determines whether the client should retry the connection attempt.
|
Determines whether the client should retry the connection attempt.
|
||||||
"""
|
"""
|
||||||
@ -30,9 +29,8 @@ class FixedReconnect(ReconnectionPolicy):
|
|||||||
self.max_attempts = attempts
|
self.max_attempts = attempts
|
||||||
self.delay = delay
|
self.delay = delay
|
||||||
|
|
||||||
def should_retry(self, attempts: int) -> bool:
|
def should_retry(self, attempts: int) -> bool | float:
|
||||||
if attempts < self.max_attempts:
|
if attempts < self.max_attempts:
|
||||||
time.sleep(self.delay)
|
return self.delay
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
@ -291,19 +291,37 @@ class Sender:
|
|||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
self._on_net_write()
|
self._on_net_write()
|
||||||
|
|
||||||
async def try_connect(self):
|
async def _try_connect(self):
|
||||||
# attempts = 0
|
attempts = 0
|
||||||
|
|
||||||
ip, port = self.addr.split(":")
|
ip, port = self.addr.split(":")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
self._reader, self._writer = await self._connector(ip, int(port))
|
self._reader, self._writer = await self._connector(ip, int(port))
|
||||||
break
|
self._logger.info(
|
||||||
|
f"auto-reconnect success after {attempts} failed attempt(s)"
|
||||||
|
)
|
||||||
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(e)
|
attempts += 1
|
||||||
# TODO: reconnection_policy
|
self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}")
|
||||||
break
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
delay = False
|
||||||
|
|
||||||
|
if self._reconnection_policy is not None:
|
||||||
|
delay = self._reconnection_policy.should_retry(attempts)
|
||||||
|
|
||||||
|
if delay:
|
||||||
|
if delay is not True:
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
continue
|
||||||
|
elif delay is not None:
|
||||||
|
self._logger.info(
|
||||||
|
f"waiting {delay} seconds before next reconnection attempt"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
def _try_fill_write(self) -> None:
|
def _try_fill_write(self) -> None:
|
||||||
if not self._requests:
|
if not self._requests:
|
||||||
@ -362,11 +380,11 @@ class Sender:
|
|||||||
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
|
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
|
||||||
|
|
||||||
def _on_error(self, error: Exception):
|
def _on_error(self, error: Exception):
|
||||||
logging.info(f"Handling error: {error}")
|
self._logger.info(f"handling error: {error}")
|
||||||
self._transport.reset()
|
self._transport.reset()
|
||||||
self._mtp.reset()
|
self._mtp.reset()
|
||||||
logging.info(
|
self._logger.info(
|
||||||
"Resetting sender state from read_buffer {}, mtp_buffer {}".format(
|
"resetting sender state from read_buffer {}, mtp_buffer {}".format(
|
||||||
len(self._read_buffer),
|
len(self._read_buffer),
|
||||||
len(self._mtp_buffer),
|
len(self._mtp_buffer),
|
||||||
)
|
)
|
||||||
@ -374,9 +392,12 @@ class Sender:
|
|||||||
self._read_buffer.clear()
|
self._read_buffer.clear()
|
||||||
self._mtp_buffer.clear()
|
self._mtp_buffer.clear()
|
||||||
|
|
||||||
# TODO: reset
|
match error:
|
||||||
|
# TODO
|
||||||
|
case DeserializationFailure():
|
||||||
|
pass
|
||||||
|
|
||||||
logging.warning(
|
self._logger.warning(
|
||||||
f"marking all {len(self._requests)} request(s) as failed: {error}"
|
f"marking all {len(self._requests)} request(s) as failed: {error}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -495,11 +516,11 @@ class Sender:
|
|||||||
req = self._pop_request(failure.msg_id)
|
req = self._pop_request(failure.msg_id)
|
||||||
|
|
||||||
if req:
|
if req:
|
||||||
logging.debug(f"Got deserialization failure {failure.error}")
|
self._logger.debug(f"got deserialization failure {failure.error}")
|
||||||
req.result.set_exception(failure.error)
|
req.result.set_exception(failure.error)
|
||||||
else:
|
else:
|
||||||
logging.info(
|
self._logger.info(
|
||||||
f"Got deserialization failure {failure.error} but no such request is saved"
|
f"got deserialization failure {failure.error} but no such request is saved"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]:
|
def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]:
|
||||||
|
Loading…
Reference in New Issue
Block a user