mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-06-17 10:36:37 +00:00
Improve error handling in transports
This commit is contained in:
parent
42633882b5
commit
b4f9d3d720
@ -3,7 +3,7 @@ from .authentication import step1 as auth_step1
|
|||||||
from .authentication import step2 as auth_step2
|
from .authentication import step2 as auth_step2
|
||||||
from .authentication import step3 as auth_step3
|
from .authentication import step3 as auth_step3
|
||||||
from .mtp import BadMessage, Deserialization, Encrypted, MsgId, Mtp, Plain, RpcError
|
from .mtp import BadMessage, Deserialization, Encrypted, MsgId, Mtp, Plain, RpcError
|
||||||
from .transport import Abridged, Full, Intermediate, MissingBytes, Transport
|
from .transport import Abridged, BadStatus, Full, Intermediate, MissingBytes, Transport
|
||||||
from .utils import DEFAULT_COMPRESSION_THRESHOLD
|
from .utils import DEFAULT_COMPRESSION_THRESHOLD
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -23,6 +23,7 @@ __all__ = [
|
|||||||
"Plain",
|
"Plain",
|
||||||
"RpcError",
|
"RpcError",
|
||||||
"Abridged",
|
"Abridged",
|
||||||
|
"BadStatus",
|
||||||
"Full",
|
"Full",
|
||||||
"Intermediate",
|
"Intermediate",
|
||||||
"MissingBytes",
|
"MissingBytes",
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -80,6 +81,25 @@ class RpcError(ValueError):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# https://core.telegram.org/mtproto/service_messages_about_messages
|
||||||
|
BAD_MSG_DESCRIPTIONS = {
|
||||||
|
16: "msg_id too low",
|
||||||
|
17: "msg_id too high",
|
||||||
|
18: "incorrect two lower order msg_id bits",
|
||||||
|
19: "container msg_id is the same as msg_id of a previously received message",
|
||||||
|
20: "message too old, and it cannot be verified whether the server has received a message with this msg_id or not",
|
||||||
|
32: "msg_seqno too low",
|
||||||
|
33: "msg_seqno too high",
|
||||||
|
34: "an even msg_seqno expected, but odd received",
|
||||||
|
35: "odd msg_seqno expected, but even received",
|
||||||
|
48: "incorrect server salt",
|
||||||
|
64: "invalid container",
|
||||||
|
}
|
||||||
|
|
||||||
|
RETRYABLE_MSG_IDS = {16, 17, 48}
|
||||||
|
NON_FATAL_MSG_IDS = RETRYABLE_MSG_IDS & {32, 33}
|
||||||
|
|
||||||
|
|
||||||
class BadMessage(ValueError):
|
class BadMessage(ValueError):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -87,15 +107,27 @@ class BadMessage(ValueError):
|
|||||||
code: int,
|
code: int,
|
||||||
caused_by: Optional[int] = None,
|
caused_by: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(f"bad msg: {code}")
|
description = BAD_MSG_DESCRIPTIONS.get(code) or "no description available"
|
||||||
|
super().__init__(f"bad msg={code}: {description}")
|
||||||
|
|
||||||
self._code = code
|
self._code = code
|
||||||
self._caused_by = caused_by
|
self._caused_by = caused_by
|
||||||
|
self.severity = (
|
||||||
|
logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def code(self) -> int:
|
def code(self) -> int:
|
||||||
return self._code
|
return self._code
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retryable(self) -> bool:
|
||||||
|
return self._code in RETRYABLE_MSG_IDS
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fatal(self) -> bool:
|
||||||
|
return self._code not in NON_FATAL_MSG_IDS
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if not isinstance(other, self.__class__):
|
if not isinstance(other, self.__class__):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from .abcs import MissingBytes, Transport
|
from .abcs import BadStatus, MissingBytes, Transport
|
||||||
from .abridged import Abridged
|
from .abridged import Abridged
|
||||||
from .full import Full
|
from .full import Full
|
||||||
from .intermediate import Intermediate
|
from .intermediate import Intermediate
|
||||||
|
|
||||||
__all__ = ["MissingBytes", "Transport", "Abridged", "Full", "Intermediate"]
|
__all__ = ["BadStatus", "MissingBytes", "Transport", "Abridged", "Full", "Intermediate"]
|
||||||
|
@ -19,3 +19,9 @@ class Transport(ABC):
|
|||||||
class MissingBytes(ValueError):
|
class MissingBytes(ValueError):
|
||||||
def __init__(self, *, expected: int, got: int) -> None:
|
def __init__(self, *, expected: int, got: int) -> None:
|
||||||
super().__init__(f"missing bytes, expected: {expected}, got: {got}")
|
super().__init__(f"missing bytes, expected: {expected}, got: {got}")
|
||||||
|
|
||||||
|
|
||||||
|
class BadStatus(ValueError):
|
||||||
|
def __init__(self, *, status: int) -> None:
|
||||||
|
super().__init__(f"transport reported bad status: {status}")
|
||||||
|
self.status = status
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import struct
|
import struct
|
||||||
|
|
||||||
from .abcs import MissingBytes, OutFn, Transport
|
from .abcs import BadStatus, MissingBytes, OutFn, Transport
|
||||||
|
|
||||||
|
|
||||||
class Abridged(Transport):
|
class Abridged(Transport):
|
||||||
@ -41,7 +41,7 @@ class Abridged(Transport):
|
|||||||
raise MissingBytes(expected=1, got=0)
|
raise MissingBytes(expected=1, got=0)
|
||||||
|
|
||||||
length = input[0]
|
length = input[0]
|
||||||
if length < 127:
|
if 1 < length < 127:
|
||||||
header_len = 1
|
header_len = 1
|
||||||
elif len(input) < 4:
|
elif len(input) < 4:
|
||||||
raise MissingBytes(expected=4, got=len(input))
|
raise MissingBytes(expected=4, got=len(input))
|
||||||
@ -49,6 +49,11 @@ class Abridged(Transport):
|
|||||||
header_len = 4
|
header_len = 4
|
||||||
length = struct.unpack_from("<i", input)[0] >> 8
|
length = struct.unpack_from("<i", input)[0] >> 8
|
||||||
|
|
||||||
|
if length <= 0:
|
||||||
|
if length < 0:
|
||||||
|
raise BadStatus(status=-length)
|
||||||
|
raise ValueError(f"bad length, expected > 0, got: {length}")
|
||||||
|
|
||||||
length *= 4
|
length *= 4
|
||||||
if len(input) < header_len + length:
|
if len(input) < header_len + length:
|
||||||
raise MissingBytes(expected=header_len + length, got=len(input))
|
raise MissingBytes(expected=header_len + length, got=len(input))
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import struct
|
import struct
|
||||||
from zlib import crc32
|
from zlib import crc32
|
||||||
|
|
||||||
from .abcs import MissingBytes, OutFn, Transport
|
from .abcs import BadStatus, MissingBytes, OutFn, Transport
|
||||||
|
|
||||||
|
|
||||||
class Full(Transport):
|
class Full(Transport):
|
||||||
@ -42,6 +42,8 @@ class Full(Transport):
|
|||||||
length = struct.unpack_from("<i", input)[0]
|
length = struct.unpack_from("<i", input)[0]
|
||||||
assert isinstance(length, int)
|
assert isinstance(length, int)
|
||||||
if length < 12:
|
if length < 12:
|
||||||
|
if length < 0:
|
||||||
|
raise BadStatus(status=-length)
|
||||||
raise ValueError(f"bad length, expected > 12, got: {length}")
|
raise ValueError(f"bad length, expected > 12, got: {length}")
|
||||||
|
|
||||||
if len(input) < length:
|
if len(input) < length:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import struct
|
import struct
|
||||||
|
|
||||||
from .abcs import MissingBytes, OutFn, Transport
|
from .abcs import BadStatus, MissingBytes, OutFn, Transport
|
||||||
|
|
||||||
|
|
||||||
class Intermediate(Transport):
|
class Intermediate(Transport):
|
||||||
@ -41,5 +41,14 @@ class Intermediate(Transport):
|
|||||||
if len(input) < length:
|
if len(input) < length:
|
||||||
raise MissingBytes(expected=length, got=len(input))
|
raise MissingBytes(expected=length, got=len(input))
|
||||||
|
|
||||||
|
if length <= 4:
|
||||||
|
if (
|
||||||
|
length >= 4
|
||||||
|
and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0
|
||||||
|
):
|
||||||
|
raise BadStatus(status=-status)
|
||||||
|
|
||||||
|
raise ValueError(f"bad length, expected > 0, got: {length}")
|
||||||
|
|
||||||
output += memoryview(input)[4 : 4 + length]
|
output += memoryview(input)[4 : 4 + length]
|
||||||
return length + 4
|
return length + 4
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import gzip
|
import gzip
|
||||||
import struct
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..tl.mtproto.types import GzipPacked, Message
|
from ..tl.mtproto.types import GzipPacked, Message
|
||||||
@ -12,10 +11,7 @@ MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
|
|||||||
|
|
||||||
|
|
||||||
def check_message_buffer(message: bytes) -> None:
|
def check_message_buffer(message: bytes) -> None:
|
||||||
if len(message) == 4:
|
if len(message) < 20:
|
||||||
neg_http_code = struct.unpack("<i", message)[0]
|
|
||||||
raise ValueError(f"transport error: {neg_http_code}")
|
|
||||||
elif len(message) < 20:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"server payload is too small to be a valid message: {message.hex()}"
|
f"server payload is too small to be a valid message: {message.hex()}"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user