Refactor error handling in crypto and mtproto modules; introduce custom error classes and improve deserialization error processing in Sender class

This commit is contained in:
Jahongir Qurbonov 2025-06-01 19:55:25 +05:00
parent 13c99d565d
commit 1c2aafce2a
No known key found for this signature in database
GPG Key ID: 256976CED13D5F2D
6 changed files with 206 additions and 25 deletions

View File

@ -7,6 +7,24 @@ from .aes import ige_decrypt, ige_encrypt
from .auth_key import AuthKey
class InvalidBufferError(ValueError):
def __init__(self) -> None:
super().__init__("Invalid ciphertext buffer length")
class AuthKeyMismatchError(ValueError):
def __init__(self) -> None:
super().__init__("Server authkey mismatches with ours")
class MsgKeyMismatchError(ValueError):
def __init__(self) -> None:
super().__init__("Server msgkey mismatches with ours")
CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError
# "where x = 0 for messages from client to server and x = 8 for those from server to client"
class Side(IntEnum):
CLIENT = 0
@ -77,14 +95,14 @@ def decrypt_data_v2(
x = int(side)
if len(ciphertext) < 24 or (len(ciphertext) - 24) % 16 != 0:
raise ValueError("invalid ciphertext buffer length")
raise InvalidBufferError()
# salt, session_id and sequence_number should also be checked.
# However, not doing so has worked fine for years.
key_id = ciphertext[:8]
if auth_key.key_id != key_id:
raise ValueError("server authkey mismatches with ours")
raise AuthKeyMismatchError()
msg_key = ciphertext[8:24]
key, iv = calc_key(auth_key, msg_key, side)
@ -93,7 +111,7 @@ def decrypt_data_v2(
# https://core.telegram.org/mtproto/security_guidelines#mtproto-encrypted-messages
our_key = sha256(auth_key.data[x + 88 : x + 120] + plaintext).digest()
if msg_key != our_key[8:24]:
raise ValueError("server msgkey mismatches with ours")
raise MsgKeyMismatchError()
return plaintext

View File

@ -62,11 +62,15 @@ from ..utils import (
)
from .types import (
BadMessageError,
DecompressionFailed,
Deserialization,
DeserializationFailure,
MsgBufferTooSmall,
MsgId,
Mtp,
RpcError,
RpcResult,
UnexpectedConstructor,
Update,
)
@ -85,6 +89,7 @@ UPDATE_IDS = {
AffectedFoundMessages.constructor_id(),
AffectedHistory.constructor_id(),
AffectedMessages.constructor_id(),
# TODO InvitedUsers
}
HEADER_LEN = 8 + 8 # salt, client_id
@ -151,7 +156,7 @@ class Encrypted(Mtp):
self._last_msg_id: int
self._in_pending_ack: list[int] = []
self._msg_count: int
self._reset_session()
self.reset()
@property
def auth_key(self) -> bytes:
@ -166,13 +171,6 @@ class Encrypted(Mtp):
def _adjusted_now(self) -> float:
return time.time() + self._time_offset
def _reset_session(self) -> None:
self._client_id = struct.unpack("<q", os.urandom(8))[0]
self._sequence = 0
self._last_msg_id = 0
self._in_pending_ack.clear()
self._msg_count = 0
def _get_new_msg_id(self) -> int:
new_msg_id = int(self._adjusted_now() * 0x100000000)
if self._last_msg_id >= new_msg_id:
@ -245,12 +243,38 @@ class Encrypted(Mtp):
result = rpc_result.result
msg_id = MsgId(req_msg_id)
inner_constructor = struct.unpack_from("<I", result)[0]
try:
inner_constructor = struct.unpack_from("<I", result)[0]
except struct.error as e:
# If the result is empty, we can't unpack it.
# This can happen if the server returns an empty response.
logging.exception(e)
self._deserialization.append(
DeserializationFailure(
msg_id=msg_id,
error=MsgBufferTooSmall(),
)
)
return
if inner_constructor == GeneratedRpcError.constructor_id():
error = RpcError._from_mtproto_error(GeneratedRpcError.from_bytes(result))
error.msg_id = msg_id
self._deserialization.append(error)
try:
error = RpcError._from_mtproto_error(
GeneratedRpcError.from_bytes(result)
)
error.msg_id = msg_id
self._deserialization.append(error)
except Exception as e:
logging.exception(e)
self._deserialization.append(
DeserializationFailure(
msg_id=msg_id,
error=UnexpectedConstructor(
id=inner_constructor,
),
)
)
elif inner_constructor == RpcAnswerUnknown.constructor_id():
pass # msg_id = rpc_drop_answer.msg_id
elif inner_constructor == RpcAnswerDroppedRunning.constructor_id():
@ -258,9 +282,15 @@ class Encrypted(Mtp):
elif inner_constructor == RpcAnswerDropped.constructor_id():
pass # dropped
elif inner_constructor == GzipPacked.constructor_id():
body = gzip_decompress(GzipPacked.from_bytes(result))
self._store_own_updates(body)
self._deserialization.append(RpcResult(msg_id, body))
try:
body = gzip_decompress(GzipPacked.from_bytes(result))
self._store_own_updates(body)
self._deserialization.append(RpcResult(msg_id, body))
except Exception as e:
logging.exception(e)
self._deserialization.append(
DeserializationFailure(msg_id=msg_id, error=DecompressionFailed())
)
else:
self._store_own_updates(result)
self._deserialization.append(RpcResult(msg_id, result))
@ -300,7 +330,7 @@ class Encrypted(Mtp):
elif bad_msg.error_code in (16, 17):
self._correct_time_offset(message.msg_id)
elif bad_msg.error_code in (32, 33):
self._reset_session()
self.reset()
else:
raise exc
@ -365,6 +395,9 @@ class Encrypted(Mtp):
for inner_message in container.messages:
self._process_message(inner_message)
def _handle_msg_copy(self, message: Message) -> None:
raise RuntimeError("msg_copy should not be used")
def _handle_gzip_packed(self, message: Message) -> None:
container = GzipPacked.from_bytes(message.body)
inner_body = gzip_decompress(container)
@ -459,3 +492,11 @@ class Encrypted(Mtp):
result = self._deserialization[:]
self._deserialization.clear()
return result
def reset(self) -> None:
self._client_id = struct.unpack("<q", os.urandom(8))[0]
self._sequence = 0
self._last_msg_id = 0
self._in_pending_ack.clear()
self._msg_count = 0
self._salt_request_msg_id = None

View File

@ -47,9 +47,10 @@ class Plain(Mtp):
if length < 0:
raise ValueError(f"bad length: expected >= 0, got: {length}")
if 20 + length > len(payload):
raise ValueError(
f"message too short, expected: {20 + length}, got {len(payload)}"
)
if 20 + length > (lp := len(payload)):
raise ValueError(f"message too short, expected: {20 + length}, got {lp}")
return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))]
def reset(self) -> None:
self._buffer.clear()

View File

@ -5,6 +5,7 @@ from typing import NewType, Optional
from typing_extensions import Self
from ...crypto.crypto import CryptoError
from ...tl.mtproto.types import RpcError as GeneratedRpcError
MsgId = NewType("MsgId", int)
@ -180,7 +181,105 @@ class BadMessageError(ValueError):
return self._code == other._code
Deserialization = Update | RpcResult | RpcError | BadMessageError
DeserializationError = ValueError
class DeserializationFailure:
__slots__ = ("msg_id", "error")
def __init__(self, msg_id: MsgId, error: DeserializationError) -> None:
self.msg_id = msg_id
self.error = error
Deserialization = (
Update | RpcResult | RpcError | BadMessageError | DeserializationFailure
)
# Deserialization errors are not fatal, so we don't subclass RpcError.
class BadAuthKeyError(DeserializationError):
def __init__(self, *args: object, got: int, expected: int) -> None:
super().__init__(f"Bad server auth key (got {got}, expected {expected})", *args)
self._got = got
self._expected = expected
@property
def got(self):
return self._got
@property
def expected(self):
return self._expected
class BadMsgIdError(DeserializationError):
def __init__(self, *args: object, got: int) -> None:
super().__init__(f"Bad server message id (got {got})", *args)
self._got = got
@property
def got(self):
return self._got
class NegativeLengthError(DeserializationError):
def __init__(self, *args: object, got: int) -> None:
super().__init__(f"Bad server message length (got {got})", *args)
self._got = got
@property
def got(self):
return self._got
class TooLongMsgError(DeserializationError):
__slots__ = ("expected", "got")
def __init__(self, *args: object, got: int, max_length: int) -> None:
super().__init__(
f"Bad server message length (got {got}, when at most it should be {max_length})",
*args,
)
self._got = got
self._expected = max_length
@property
def got(self):
return self._got
@property
def expected(self):
return self._expected
class MsgBufferTooSmall(DeserializationError):
def __init__(self, *args: object) -> None:
super().__init__(
"Server responded with a payload that's too small to fit a valid message",
*args,
)
class DecompressionFailed(DeserializationError):
def __init__(self, *args: object) -> None:
super().__init__("Failed to decompress server's data", *args)
class UnexpectedConstructor(DeserializationError):
def __init__(self, *args: object, id: int) -> None:
super().__init__(f"Unexpected constructor: {id:08x}", *args)
class DecryptionError(DeserializationError):
def __init__(self, *args: object, error: CryptoError) -> None:
super().__init__(f"failed to decrypt message: {error}", *args)
self._error = error
@property
def error(self):
return self._error
# https://core.telegram.org/mtproto/description
@ -209,3 +308,9 @@ class Mtp(ABC):
"""
Deserialize incoming buffer payload.
"""
@abstractmethod
def reset(self) -> None:
"""
Reset the internal buffer.
"""

View File

@ -24,6 +24,7 @@ from ..mtproto import (
Update,
authentication,
)
from ..mtproto.mtp.types import DeserializationFailure
from ..tl import Request as RemoteCall
from ..tl.abcs import Updates
from ..tl.core import Serializable
@ -334,8 +335,12 @@ class Sender:
self._process_result(result)
elif isinstance(result, RpcError):
self._process_error(result)
else:
elif isinstance(result, BadMessageError):
self._process_bad_message(result)
elif isinstance(result, DeserializationFailure):
self._process_deserialize_error(result)
else:
raise RuntimeError(f"Unexpected result: {result}")
def _process_update(self, update: bytes | bytearray | memoryview) -> None:
try:
@ -424,6 +429,17 @@ class Sender:
result._caused_by = struct.unpack_from("<I", req.body)[0]
req.result.set_exception(result)
def _process_deserialize_error(self, failure: DeserializationFailure):
req = self._pop_request(failure.msg_id)
if req:
logging.debug(f"Got deserialization failure {failure.error}")
req.result.set_exception(failure.error)
else:
logging.info(
f"Got deserialization failure {failure.error} but no such request is saved"
)
def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]:
for req in self._requests:
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: