diff --git a/client/src/telethon/_impl/crypto/crypto.py b/client/src/telethon/_impl/crypto/crypto.py index 1c47962e..a5a60a38 100644 --- a/client/src/telethon/_impl/crypto/crypto.py +++ b/client/src/telethon/_impl/crypto/crypto.py @@ -9,17 +9,17 @@ from .auth_key import AuthKey class InvalidBufferError(ValueError): def __init__(self) -> None: - super().__init__("Invalid ciphertext buffer length") + super().__init__("invalid ciphertext buffer length") class AuthKeyMismatchError(ValueError): def __init__(self) -> None: - super().__init__("Server authkey mismatches with ours") + super().__init__("server authkey mismatches with ours") class MsgKeyMismatchError(ValueError): def __init__(self) -> None: - super().__init__("Server msgkey mismatches with ours") + super().__init__("server msgkey mismatches with ours") CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index 78d2faa3..a9290fe8 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -249,7 +249,7 @@ class Encrypted(Mtp): except struct.error: # If the result is empty, we can't unpack it. # This can happen if the server returns an empty response. - logging.exception("Failed to unpack inner_constructor") + logging.exception("failed to unpack inner_constructor") self._deserialization.append( DeserializationFailure( msg_id=msg_id, @@ -266,7 +266,7 @@ class Encrypted(Mtp): error.msg_id = msg_id self._deserialization.append(error) except Exception: - logging.exception("Failed to deserialize error") + logging.exception("failed to deserialize error") self._deserialization.append( DeserializationFailure( msg_id=msg_id, @@ -287,7 +287,7 @@ class Encrypted(Mtp): self._store_own_updates(body) self._deserialization.append(RpcResult(msg_id, body)) except Exception: - logging.exception("Failed to decompress response") + logging.exception("failed to decompress response") self._deserialization.append( DeserializationFailure(msg_id=msg_id, error=DecompressionFailed()) ) diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index 5594b7bb..4d4b43bc 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -181,26 +181,10 @@ class BadMessageError(ValueError): return self._code == other._code -DeserializationError = ValueError - - -class DeserializationFailure: - __slots__ = ("msg_id", "error") - - def __init__(self, msg_id: MsgId, error: DeserializationError) -> None: - self.msg_id = msg_id - self.error = error - - -Deserialization = ( - Update | RpcResult | RpcError | BadMessageError | DeserializationFailure -) - - # Deserialization errors are not fatal, so we don't subclass RpcError. -class BadAuthKeyError(DeserializationError): +class BadAuthKeyError(ValueError): def __init__(self, *args: object, got: int, expected: int) -> None: - super().__init__(f"Bad server auth key (got {got}, expected {expected})", *args) + super().__init__(f"bad server auth key (got {got}, expected {expected})", *args) self._got = got self._expected = expected @@ -213,9 +197,9 @@ class BadAuthKeyError(DeserializationError): return self._expected -class BadMsgIdError(DeserializationError): +class BadMsgIdError(ValueError): def __init__(self, *args: object, got: int) -> None: - super().__init__(f"Bad server message id (got {got})", *args) + super().__init__(f"bad server message id (got {got})", *args) self._got = got @property @@ -223,9 +207,9 @@ class BadMsgIdError(DeserializationError): return self._got -class NegativeLengthError(DeserializationError): +class NegativeLengthError(ValueError): def __init__(self, *args: object, got: int) -> None: - super().__init__(f"Bad server message length (got {got})", *args) + super().__init__(f"bad server message length (got {got})", *args) self._got = got @property @@ -233,10 +217,10 @@ class NegativeLengthError(DeserializationError): return self._got -class TooLongMsgError(DeserializationError): +class TooLongMsgError(ValueError): def __init__(self, *args: object, got: int, max_length: int) -> None: super().__init__( - f"Bad server message length (got {got}, when at most it should be {max_length})", + f"bad server message length (got {got}, when at most it should be {max_length})", *args, ) self._got = got @@ -251,25 +235,25 @@ class TooLongMsgError(DeserializationError): return self._expected -class MsgBufferTooSmall(DeserializationError): +class MsgBufferTooSmall(ValueError): def __init__(self, *args: object) -> None: super().__init__( - "Server responded with a payload that's too small to fit a valid message", + "server responded with a payload that's too small to fit a valid message", *args, ) -class DecompressionFailed(DeserializationError): +class DecompressionFailed(ValueError): def __init__(self, *args: object) -> None: - super().__init__("Failed to decompress server's data", *args) + super().__init__("failed to decompress server's data", *args) -class UnexpectedConstructor(DeserializationError): +class UnexpectedConstructor(ValueError): def __init__(self, *args: object, id: int) -> None: - super().__init__(f"Unexpected constructor: {id:08x}", *args) + super().__init__(f"unexpected constructor: {id:08x}", *args) -class DecryptionError(DeserializationError): +class DecryptionError(ValueError): def __init__(self, *args: object, error: CryptoError) -> None: super().__init__(f"failed to decrypt message: {error}", *args) @@ -280,6 +264,31 @@ class DecryptionError(DeserializationError): return self._error +DeserializationError = ( + BadAuthKeyError + | BadMsgIdError + | NegativeLengthError + | TooLongMsgError + | MsgBufferTooSmall + | DecompressionFailed + | UnexpectedConstructor + | DecryptionError +) + + +class DeserializationFailure: + __slots__ = ("msg_id", "error") + + def __init__(self, msg_id: MsgId, error: DeserializationError) -> None: + self.msg_id = msg_id + self.error = error + + +Deserialization = ( + Update | RpcResult | RpcError | BadMessageError | DeserializationFailure +) + + # https://core.telegram.org/mtproto/description class Mtp(ABC): @abstractmethod diff --git a/client/src/telethon/_impl/mtproto/transport/abcs.py b/client/src/telethon/_impl/mtproto/transport/abcs.py index b3a16f58..5e319655 100644 --- a/client/src/telethon/_impl/mtproto/transport/abcs.py +++ b/client/src/telethon/_impl/mtproto/transport/abcs.py @@ -25,7 +25,27 @@ class MissingBytesError(ValueError): super().__init__(f"missing bytes, expected: {expected}, got: {got}") +class BadLenError(ValueError): + def __init__(self, *, got: int) -> None: + super().__init__(f"bad len (got {got})") + + +class BadSeqError(ValueError): + def __init__(self, *, expected: int, got: int) -> None: + super().__init__(f"bad seq (expected {expected}, got {got})") + + +class BadCrcError(ValueError): + def __init__(self, *, expected: int, got: int) -> None: + super().__init__(f"bad crc (expected {expected}, got {got})") + + class BadStatusError(ValueError): def __init__(self, *, status: int) -> None: - super().__init__(f"transport reported bad status: {status}") + super().__init__(f"bad status (negative length -{status})") self.status = status + + +TransportError = ( + MissingBytesError | BadLenError | BadSeqError | BadCrcError | BadStatusError +) diff --git a/client/src/telethon/_impl/mtproto/transport/abridged.py b/client/src/telethon/_impl/mtproto/transport/abridged.py index 65b92a14..2cd4fc9d 100644 --- a/client/src/telethon/_impl/mtproto/transport/abridged.py +++ b/client/src/telethon/_impl/mtproto/transport/abridged.py @@ -63,5 +63,5 @@ class Abridged(Transport): return header_len + length def reset(self): - logging.info("Resetting sending of header in abridged transport") + logging.info("resetting sending of header in abridged transport") self._init = False diff --git a/client/src/telethon/_impl/mtproto/transport/full.py b/client/src/telethon/_impl/mtproto/transport/full.py index f04d751c..c8a24ed1 100644 --- a/client/src/telethon/_impl/mtproto/transport/full.py +++ b/client/src/telethon/_impl/mtproto/transport/full.py @@ -64,6 +64,6 @@ class Full(Transport): return length def reset(self): - logging.info("Resetting recv and send seqs in full transport") + logging.info("resetting recv and send seqs in full transport") self._send_seq = 0 self._recv_seq = 0 diff --git a/client/src/telethon/_impl/mtproto/transport/intermediate.py b/client/src/telethon/_impl/mtproto/transport/intermediate.py index e75adff6..a533ab9a 100644 --- a/client/src/telethon/_impl/mtproto/transport/intermediate.py +++ b/client/src/telethon/_impl/mtproto/transport/intermediate.py @@ -55,5 +55,5 @@ class Intermediate(Transport): return length + 4 def reset(self): - logging.info("Resetting sending of header in intermediate transport") + logging.info("resetting sending of header in intermediate transport") self._init = False diff --git a/client/src/telethon/_impl/mtsender/errors.py b/client/src/telethon/_impl/mtsender/errors.py new file mode 100644 index 00000000..227d6fdd --- /dev/null +++ b/client/src/telethon/_impl/mtsender/errors.py @@ -0,0 +1,11 @@ +import io + +from ..mtproto.mtp.types import DeserializationError +from ..mtproto.transport.abcs import TransportError + +ReadError = io.BlockingIOError | TransportError | DeserializationError + + +class IOError(io.BlockingIOError): + def __init__(self, *args: object) -> None: + super().__init__(*args) diff --git a/client/src/telethon/_impl/mtsender/reconnection.py b/client/src/telethon/_impl/mtsender/reconnection.py index 2fb51ddd..363aad6d 100644 --- a/client/src/telethon/_impl/mtsender/reconnection.py +++ b/client/src/telethon/_impl/mtsender/reconnection.py @@ -1,4 +1,3 @@ -import time from abc import ABC, abstractmethod @@ -11,7 +10,7 @@ class ReconnectionPolicy(ABC): """ @abstractmethod - def should_retry(self, attempts: int) -> bool: + def should_retry(self, attempts: int) -> bool | float: """ Determines whether the client should retry the connection attempt. """ @@ -30,9 +29,8 @@ class FixedReconnect(ReconnectionPolicy): self.max_attempts = attempts self.delay = delay - def should_retry(self, attempts: int) -> bool: + def should_retry(self, attempts: int) -> bool | float: if attempts < self.max_attempts: - time.sleep(self.delay) - return True + return self.delay return False diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 242f4abd..9e63a4a8 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -291,19 +291,37 @@ class Sender: await self._writer.drain() self._on_net_write() - async def try_connect(self): - # attempts = 0 + async def _try_connect(self): + attempts = 0 ip, port = self.addr.split(":") while True: try: self._reader, self._writer = await self._connector(ip, int(port)) - break + self._logger.info( + f"auto-reconnect success after {attempts} failed attempt(s)" + ) + return except Exception as e: - logging.exception(e) - # TODO: reconnection_policy - break + attempts += 1 + self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}") + await asyncio.sleep(1) + + delay = False + + if self._reconnection_policy is not None: + delay = self._reconnection_policy.should_retry(attempts) + + if delay: + if delay is not True: + await asyncio.sleep(delay) + continue + elif delay is not None: + self._logger.info( + f"waiting {delay} seconds before next reconnection attempt" + ) + await asyncio.sleep(delay) def _try_fill_write(self) -> None: if not self._requests: @@ -362,11 +380,11 @@ class Sender: req.state = Sent(req.state.msg_id, req.state.container_msg_id) def _on_error(self, error: Exception): - logging.info(f"Handling error: {error}") + self._logger.info(f"handling error: {error}") self._transport.reset() self._mtp.reset() - logging.info( - "Resetting sender state from read_buffer {}, mtp_buffer {}".format( + self._logger.info( + "resetting sender state from read_buffer {}, mtp_buffer {}".format( len(self._read_buffer), len(self._mtp_buffer), ) @@ -374,9 +392,12 @@ class Sender: self._read_buffer.clear() self._mtp_buffer.clear() - # TODO: reset + match error: + # TODO + case DeserializationFailure(): + pass - logging.warning( + self._logger.warning( f"marking all {len(self._requests)} request(s) as failed: {error}" ) @@ -495,11 +516,11 @@ class Sender: req = self._pop_request(failure.msg_id) if req: - logging.debug(f"Got deserialization failure {failure.error}") + self._logger.debug(f"got deserialization failure {failure.error}") req.result.set_exception(failure.error) else: - logging.info( - f"Got deserialization failure {failure.error} but no such request is saved" + self._logger.info( + f"got deserialization failure {failure.error} but no such request is saved" ) def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]: