Improve error handling in transports

This commit is contained in:
Lonami Exo 2023-10-13 22:59:26 +02:00
parent 42633882b5
commit b4f9d3d720
8 changed files with 64 additions and 13 deletions

View File

@ -3,7 +3,7 @@ from .authentication import step1 as auth_step1
from .authentication import step2 as auth_step2 from .authentication import step2 as auth_step2
from .authentication import step3 as auth_step3 from .authentication import step3 as auth_step3
from .mtp import BadMessage, Deserialization, Encrypted, MsgId, Mtp, Plain, RpcError from .mtp import BadMessage, Deserialization, Encrypted, MsgId, Mtp, Plain, RpcError
from .transport import Abridged, Full, Intermediate, MissingBytes, Transport from .transport import Abridged, BadStatus, Full, Intermediate, MissingBytes, Transport
from .utils import DEFAULT_COMPRESSION_THRESHOLD from .utils import DEFAULT_COMPRESSION_THRESHOLD
__all__ = [ __all__ = [
@ -23,6 +23,7 @@ __all__ = [
"Plain", "Plain",
"RpcError", "RpcError",
"Abridged", "Abridged",
"BadStatus",
"Full", "Full",
"Intermediate", "Intermediate",
"MissingBytes", "MissingBytes",

View File

@ -1,3 +1,4 @@
import logging
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
@ -80,6 +81,25 @@ class RpcError(ValueError):
) )
# https://core.telegram.org/mtproto/service_messages_about_messages
BAD_MSG_DESCRIPTIONS = {
16: "msg_id too low",
17: "msg_id too high",
18: "incorrect two lower order msg_id bits",
19: "container msg_id is the same as msg_id of a previously received message",
20: "message too old, and it cannot be verified whether the server has received a message with this msg_id or not",
32: "msg_seqno too low",
33: "msg_seqno too high",
34: "an even msg_seqno expected, but odd received",
35: "odd msg_seqno expected, but even received",
48: "incorrect server salt",
64: "invalid container",
}
RETRYABLE_MSG_IDS = {16, 17, 48}
NON_FATAL_MSG_IDS = RETRYABLE_MSG_IDS & {32, 33}
class BadMessage(ValueError): class BadMessage(ValueError):
def __init__( def __init__(
self, self,
@ -87,15 +107,27 @@ class BadMessage(ValueError):
code: int, code: int,
caused_by: Optional[int] = None, caused_by: Optional[int] = None,
) -> None: ) -> None:
super().__init__(f"bad msg: {code}") description = BAD_MSG_DESCRIPTIONS.get(code) or "no description available"
super().__init__(f"bad msg={code}: {description}")
self._code = code self._code = code
self._caused_by = caused_by self._caused_by = caused_by
self.severity = (
logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR
)
@property @property
def code(self) -> int: def code(self) -> int:
return self._code return self._code
@property
def retryable(self) -> bool:
return self._code in RETRYABLE_MSG_IDS
@property
def fatal(self) -> bool:
return self._code not in NON_FATAL_MSG_IDS
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented

View File

@ -1,6 +1,6 @@
from .abcs import MissingBytes, Transport from .abcs import BadStatus, MissingBytes, Transport
from .abridged import Abridged from .abridged import Abridged
from .full import Full from .full import Full
from .intermediate import Intermediate from .intermediate import Intermediate
__all__ = ["MissingBytes", "Transport", "Abridged", "Full", "Intermediate"] __all__ = ["BadStatus", "MissingBytes", "Transport", "Abridged", "Full", "Intermediate"]

View File

@ -19,3 +19,9 @@ class Transport(ABC):
class MissingBytes(ValueError): class MissingBytes(ValueError):
def __init__(self, *, expected: int, got: int) -> None: def __init__(self, *, expected: int, got: int) -> None:
super().__init__(f"missing bytes, expected: {expected}, got: {got}") super().__init__(f"missing bytes, expected: {expected}, got: {got}")
class BadStatus(ValueError):
def __init__(self, *, status: int) -> None:
super().__init__(f"transport reported bad status: {status}")
self.status = status

View File

@ -1,6 +1,6 @@
import struct import struct
from .abcs import MissingBytes, OutFn, Transport from .abcs import BadStatus, MissingBytes, OutFn, Transport
class Abridged(Transport): class Abridged(Transport):
@ -41,7 +41,7 @@ class Abridged(Transport):
raise MissingBytes(expected=1, got=0) raise MissingBytes(expected=1, got=0)
length = input[0] length = input[0]
if length < 127: if 1 < length < 127:
header_len = 1 header_len = 1
elif len(input) < 4: elif len(input) < 4:
raise MissingBytes(expected=4, got=len(input)) raise MissingBytes(expected=4, got=len(input))
@ -49,6 +49,11 @@ class Abridged(Transport):
header_len = 4 header_len = 4
length = struct.unpack_from("<i", input)[0] >> 8 length = struct.unpack_from("<i", input)[0] >> 8
if length <= 0:
if length < 0:
raise BadStatus(status=-length)
raise ValueError(f"bad length, expected > 0, got: {length}")
length *= 4 length *= 4
if len(input) < header_len + length: if len(input) < header_len + length:
raise MissingBytes(expected=header_len + length, got=len(input)) raise MissingBytes(expected=header_len + length, got=len(input))

View File

@ -1,7 +1,7 @@
import struct import struct
from zlib import crc32 from zlib import crc32
from .abcs import MissingBytes, OutFn, Transport from .abcs import BadStatus, MissingBytes, OutFn, Transport
class Full(Transport): class Full(Transport):
@ -42,6 +42,8 @@ class Full(Transport):
length = struct.unpack_from("<i", input)[0] length = struct.unpack_from("<i", input)[0]
assert isinstance(length, int) assert isinstance(length, int)
if length < 12: if length < 12:
if length < 0:
raise BadStatus(status=-length)
raise ValueError(f"bad length, expected > 12, got: {length}") raise ValueError(f"bad length, expected > 12, got: {length}")
if len(input) < length: if len(input) < length:

View File

@ -1,6 +1,6 @@
import struct import struct
from .abcs import MissingBytes, OutFn, Transport from .abcs import BadStatus, MissingBytes, OutFn, Transport
class Intermediate(Transport): class Intermediate(Transport):
@ -41,5 +41,14 @@ class Intermediate(Transport):
if len(input) < length: if len(input) < length:
raise MissingBytes(expected=length, got=len(input)) raise MissingBytes(expected=length, got=len(input))
if length <= 4:
if (
length >= 4
and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0
):
raise BadStatus(status=-status)
raise ValueError(f"bad length, expected > 0, got: {length}")
output += memoryview(input)[4 : 4 + length] output += memoryview(input)[4 : 4 + length]
return length + 4 return length + 4

View File

@ -1,5 +1,4 @@
import gzip import gzip
import struct
from typing import Optional from typing import Optional
from ..tl.mtproto.types import GzipPacked, Message from ..tl.mtproto.types import GzipPacked, Message
@ -12,10 +11,7 @@ MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
def check_message_buffer(message: bytes) -> None: def check_message_buffer(message: bytes) -> None:
if len(message) == 4: if len(message) < 20:
neg_http_code = struct.unpack("<i", message)[0]
raise ValueError(f"transport error: {neg_http_code}")
elif len(message) < 20:
raise ValueError( raise ValueError(
f"server payload is too small to be a valid message: {message.hex()}" f"server payload is too small to be a valid message: {message.hex()}"
) )