Refactor error handling in crypto and mtproto modules; introduce custom error classes and improve deserialization error processing in Sender class

This commit is contained in:
Jahongir Qurbonov 2025-06-01 19:55:25 +05:00
parent 13c99d565d
commit 1c2aafce2a
No known key found for this signature in database
GPG Key ID: 256976CED13D5F2D
6 changed files with 206 additions and 25 deletions

View File

@ -7,6 +7,24 @@ from .aes import ige_decrypt, ige_encrypt
from .auth_key import AuthKey from .auth_key import AuthKey
class InvalidBufferError(ValueError):
def __init__(self) -> None:
super().__init__("Invalid ciphertext buffer length")
class AuthKeyMismatchError(ValueError):
def __init__(self) -> None:
super().__init__("Server authkey mismatches with ours")
class MsgKeyMismatchError(ValueError):
def __init__(self) -> None:
super().__init__("Server msgkey mismatches with ours")
CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError
# "where x = 0 for messages from client to server and x = 8 for those from server to client" # "where x = 0 for messages from client to server and x = 8 for those from server to client"
class Side(IntEnum): class Side(IntEnum):
CLIENT = 0 CLIENT = 0
@ -77,14 +95,14 @@ def decrypt_data_v2(
x = int(side) x = int(side)
if len(ciphertext) < 24 or (len(ciphertext) - 24) % 16 != 0: if len(ciphertext) < 24 or (len(ciphertext) - 24) % 16 != 0:
raise ValueError("invalid ciphertext buffer length") raise InvalidBufferError()
# salt, session_id and sequence_number should also be checked. # salt, session_id and sequence_number should also be checked.
# However, not doing so has worked fine for years. # However, not doing so has worked fine for years.
key_id = ciphertext[:8] key_id = ciphertext[:8]
if auth_key.key_id != key_id: if auth_key.key_id != key_id:
raise ValueError("server authkey mismatches with ours") raise AuthKeyMismatchError()
msg_key = ciphertext[8:24] msg_key = ciphertext[8:24]
key, iv = calc_key(auth_key, msg_key, side) key, iv = calc_key(auth_key, msg_key, side)
@ -93,7 +111,7 @@ def decrypt_data_v2(
# https://core.telegram.org/mtproto/security_guidelines#mtproto-encrypted-messages # https://core.telegram.org/mtproto/security_guidelines#mtproto-encrypted-messages
our_key = sha256(auth_key.data[x + 88 : x + 120] + plaintext).digest() our_key = sha256(auth_key.data[x + 88 : x + 120] + plaintext).digest()
if msg_key != our_key[8:24]: if msg_key != our_key[8:24]:
raise ValueError("server msgkey mismatches with ours") raise MsgKeyMismatchError()
return plaintext return plaintext

View File

@ -62,11 +62,15 @@ from ..utils import (
) )
from .types import ( from .types import (
BadMessageError, BadMessageError,
DecompressionFailed,
Deserialization, Deserialization,
DeserializationFailure,
MsgBufferTooSmall,
MsgId, MsgId,
Mtp, Mtp,
RpcError, RpcError,
RpcResult, RpcResult,
UnexpectedConstructor,
Update, Update,
) )
@ -85,6 +89,7 @@ UPDATE_IDS = {
AffectedFoundMessages.constructor_id(), AffectedFoundMessages.constructor_id(),
AffectedHistory.constructor_id(), AffectedHistory.constructor_id(),
AffectedMessages.constructor_id(), AffectedMessages.constructor_id(),
# TODO InvitedUsers
} }
HEADER_LEN = 8 + 8 # salt, client_id HEADER_LEN = 8 + 8 # salt, client_id
@ -151,7 +156,7 @@ class Encrypted(Mtp):
self._last_msg_id: int self._last_msg_id: int
self._in_pending_ack: list[int] = [] self._in_pending_ack: list[int] = []
self._msg_count: int self._msg_count: int
self._reset_session() self.reset()
@property @property
def auth_key(self) -> bytes: def auth_key(self) -> bytes:
@ -166,13 +171,6 @@ class Encrypted(Mtp):
def _adjusted_now(self) -> float: def _adjusted_now(self) -> float:
return time.time() + self._time_offset return time.time() + self._time_offset
def _reset_session(self) -> None:
self._client_id = struct.unpack("<q", os.urandom(8))[0]
self._sequence = 0
self._last_msg_id = 0
self._in_pending_ack.clear()
self._msg_count = 0
def _get_new_msg_id(self) -> int: def _get_new_msg_id(self) -> int:
new_msg_id = int(self._adjusted_now() * 0x100000000) new_msg_id = int(self._adjusted_now() * 0x100000000)
if self._last_msg_id >= new_msg_id: if self._last_msg_id >= new_msg_id:
@ -245,12 +243,38 @@ class Encrypted(Mtp):
result = rpc_result.result result = rpc_result.result
msg_id = MsgId(req_msg_id) msg_id = MsgId(req_msg_id)
inner_constructor = struct.unpack_from("<I", result)[0]
try:
inner_constructor = struct.unpack_from("<I", result)[0]
except struct.error as e:
# If the result is empty, we can't unpack it.
# This can happen if the server returns an empty response.
logging.exception(e)
self._deserialization.append(
DeserializationFailure(
msg_id=msg_id,
error=MsgBufferTooSmall(),
)
)
return
if inner_constructor == GeneratedRpcError.constructor_id(): if inner_constructor == GeneratedRpcError.constructor_id():
error = RpcError._from_mtproto_error(GeneratedRpcError.from_bytes(result)) try:
error.msg_id = msg_id error = RpcError._from_mtproto_error(
self._deserialization.append(error) GeneratedRpcError.from_bytes(result)
)
error.msg_id = msg_id
self._deserialization.append(error)
except Exception as e:
logging.exception(e)
self._deserialization.append(
DeserializationFailure(
msg_id=msg_id,
error=UnexpectedConstructor(
id=inner_constructor,
),
)
)
elif inner_constructor == RpcAnswerUnknown.constructor_id(): elif inner_constructor == RpcAnswerUnknown.constructor_id():
pass # msg_id = rpc_drop_answer.msg_id pass # msg_id = rpc_drop_answer.msg_id
elif inner_constructor == RpcAnswerDroppedRunning.constructor_id(): elif inner_constructor == RpcAnswerDroppedRunning.constructor_id():
@ -258,9 +282,15 @@ class Encrypted(Mtp):
elif inner_constructor == RpcAnswerDropped.constructor_id(): elif inner_constructor == RpcAnswerDropped.constructor_id():
pass # dropped pass # dropped
elif inner_constructor == GzipPacked.constructor_id(): elif inner_constructor == GzipPacked.constructor_id():
body = gzip_decompress(GzipPacked.from_bytes(result)) try:
self._store_own_updates(body) body = gzip_decompress(GzipPacked.from_bytes(result))
self._deserialization.append(RpcResult(msg_id, body)) self._store_own_updates(body)
self._deserialization.append(RpcResult(msg_id, body))
except Exception as e:
logging.exception(e)
self._deserialization.append(
DeserializationFailure(msg_id=msg_id, error=DecompressionFailed())
)
else: else:
self._store_own_updates(result) self._store_own_updates(result)
self._deserialization.append(RpcResult(msg_id, result)) self._deserialization.append(RpcResult(msg_id, result))
@ -300,7 +330,7 @@ class Encrypted(Mtp):
elif bad_msg.error_code in (16, 17): elif bad_msg.error_code in (16, 17):
self._correct_time_offset(message.msg_id) self._correct_time_offset(message.msg_id)
elif bad_msg.error_code in (32, 33): elif bad_msg.error_code in (32, 33):
self._reset_session() self.reset()
else: else:
raise exc raise exc
@ -365,6 +395,9 @@ class Encrypted(Mtp):
for inner_message in container.messages: for inner_message in container.messages:
self._process_message(inner_message) self._process_message(inner_message)
def _handle_msg_copy(self, message: Message) -> None:
raise RuntimeError("msg_copy should not be used")
def _handle_gzip_packed(self, message: Message) -> None: def _handle_gzip_packed(self, message: Message) -> None:
container = GzipPacked.from_bytes(message.body) container = GzipPacked.from_bytes(message.body)
inner_body = gzip_decompress(container) inner_body = gzip_decompress(container)
@ -459,3 +492,11 @@ class Encrypted(Mtp):
result = self._deserialization[:] result = self._deserialization[:]
self._deserialization.clear() self._deserialization.clear()
return result return result
def reset(self) -> None:
self._client_id = struct.unpack("<q", os.urandom(8))[0]
self._sequence = 0
self._last_msg_id = 0
self._in_pending_ack.clear()
self._msg_count = 0
self._salt_request_msg_id = None

View File

@ -47,9 +47,10 @@ class Plain(Mtp):
if length < 0: if length < 0:
raise ValueError(f"bad length: expected >= 0, got: {length}") raise ValueError(f"bad length: expected >= 0, got: {length}")
if 20 + length > len(payload): if 20 + length > (lp := len(payload)):
raise ValueError( raise ValueError(f"message too short, expected: {20 + length}, got {lp}")
f"message too short, expected: {20 + length}, got {len(payload)}"
)
return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))] return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))]
def reset(self) -> None:
self._buffer.clear()

View File

@ -5,6 +5,7 @@ from typing import NewType, Optional
from typing_extensions import Self from typing_extensions import Self
from ...crypto.crypto import CryptoError
from ...tl.mtproto.types import RpcError as GeneratedRpcError from ...tl.mtproto.types import RpcError as GeneratedRpcError
MsgId = NewType("MsgId", int) MsgId = NewType("MsgId", int)
@ -180,7 +181,105 @@ class BadMessageError(ValueError):
return self._code == other._code return self._code == other._code
Deserialization = Update | RpcResult | RpcError | BadMessageError DeserializationError = ValueError
class DeserializationFailure:
__slots__ = ("msg_id", "error")
def __init__(self, msg_id: MsgId, error: DeserializationError) -> None:
self.msg_id = msg_id
self.error = error
Deserialization = (
Update | RpcResult | RpcError | BadMessageError | DeserializationFailure
)
# Deserialization errors are not fatal, so we don't subclass RpcError.
class BadAuthKeyError(DeserializationError):
def __init__(self, *args: object, got: int, expected: int) -> None:
super().__init__(f"Bad server auth key (got {got}, expected {expected})", *args)
self._got = got
self._expected = expected
@property
def got(self):
return self._got
@property
def expected(self):
return self._expected
class BadMsgIdError(DeserializationError):
def __init__(self, *args: object, got: int) -> None:
super().__init__(f"Bad server message id (got {got})", *args)
self._got = got
@property
def got(self):
return self._got
class NegativeLengthError(DeserializationError):
def __init__(self, *args: object, got: int) -> None:
super().__init__(f"Bad server message length (got {got})", *args)
self._got = got
@property
def got(self):
return self._got
class TooLongMsgError(DeserializationError):
__slots__ = ("expected", "got")
def __init__(self, *args: object, got: int, max_length: int) -> None:
super().__init__(
f"Bad server message length (got {got}, when at most it should be {max_length})",
*args,
)
self._got = got
self._expected = max_length
@property
def got(self):
return self._got
@property
def expected(self):
return self._expected
class MsgBufferTooSmall(DeserializationError):
def __init__(self, *args: object) -> None:
super().__init__(
"Server responded with a payload that's too small to fit a valid message",
*args,
)
class DecompressionFailed(DeserializationError):
def __init__(self, *args: object) -> None:
super().__init__("Failed to decompress server's data", *args)
class UnexpectedConstructor(DeserializationError):
def __init__(self, *args: object, id: int) -> None:
super().__init__(f"Unexpected constructor: {id:08x}", *args)
class DecryptionError(DeserializationError):
def __init__(self, *args: object, error: CryptoError) -> None:
super().__init__(f"failed to decrypt message: {error}", *args)
self._error = error
@property
def error(self):
return self._error
# https://core.telegram.org/mtproto/description # https://core.telegram.org/mtproto/description
@ -209,3 +308,9 @@ class Mtp(ABC):
""" """
Deserialize incoming buffer payload. Deserialize incoming buffer payload.
""" """
@abstractmethod
def reset(self) -> None:
"""
Reset the internal buffer.
"""

View File

@ -24,6 +24,7 @@ from ..mtproto import (
Update, Update,
authentication, authentication,
) )
from ..mtproto.mtp.types import DeserializationFailure
from ..tl import Request as RemoteCall from ..tl import Request as RemoteCall
from ..tl.abcs import Updates from ..tl.abcs import Updates
from ..tl.core import Serializable from ..tl.core import Serializable
@ -334,8 +335,12 @@ class Sender:
self._process_result(result) self._process_result(result)
elif isinstance(result, RpcError): elif isinstance(result, RpcError):
self._process_error(result) self._process_error(result)
else: elif isinstance(result, BadMessageError):
self._process_bad_message(result) self._process_bad_message(result)
elif isinstance(result, DeserializationFailure):
self._process_deserialize_error(result)
else:
raise RuntimeError(f"Unexpected result: {result}")
def _process_update(self, update: bytes | bytearray | memoryview) -> None: def _process_update(self, update: bytes | bytearray | memoryview) -> None:
try: try:
@ -424,6 +429,17 @@ class Sender:
result._caused_by = struct.unpack_from("<I", req.body)[0] result._caused_by = struct.unpack_from("<I", req.body)[0]
req.result.set_exception(result) req.result.set_exception(result)
def _process_deserialize_error(self, failure: DeserializationFailure):
req = self._pop_request(failure.msg_id)
if req:
logging.debug(f"Got deserialization failure {failure.error}")
req.result.set_exception(failure.error)
else:
logging.info(
f"Got deserialization failure {failure.error} but no such request is saved"
)
def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]: def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]:
for req in self._requests: for req in self._requests:
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: