mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-06-17 10:36:37 +00:00
Refactor error handling in crypto and mtproto modules; introduce custom error classes and improve deserialization error processing in Sender class
This commit is contained in:
parent
13c99d565d
commit
1c2aafce2a
@ -7,6 +7,24 @@ from .aes import ige_decrypt, ige_encrypt
|
||||
from .auth_key import AuthKey
|
||||
|
||||
|
||||
class InvalidBufferError(ValueError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Invalid ciphertext buffer length")
|
||||
|
||||
|
||||
class AuthKeyMismatchError(ValueError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Server authkey mismatches with ours")
|
||||
|
||||
|
||||
class MsgKeyMismatchError(ValueError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Server msgkey mismatches with ours")
|
||||
|
||||
|
||||
CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError
|
||||
|
||||
|
||||
# "where x = 0 for messages from client to server and x = 8 for those from server to client"
|
||||
class Side(IntEnum):
|
||||
CLIENT = 0
|
||||
@ -77,14 +95,14 @@ def decrypt_data_v2(
|
||||
x = int(side)
|
||||
|
||||
if len(ciphertext) < 24 or (len(ciphertext) - 24) % 16 != 0:
|
||||
raise ValueError("invalid ciphertext buffer length")
|
||||
raise InvalidBufferError()
|
||||
|
||||
# salt, session_id and sequence_number should also be checked.
|
||||
# However, not doing so has worked fine for years.
|
||||
|
||||
key_id = ciphertext[:8]
|
||||
if auth_key.key_id != key_id:
|
||||
raise ValueError("server authkey mismatches with ours")
|
||||
raise AuthKeyMismatchError()
|
||||
|
||||
msg_key = ciphertext[8:24]
|
||||
key, iv = calc_key(auth_key, msg_key, side)
|
||||
@ -93,7 +111,7 @@ def decrypt_data_v2(
|
||||
# https://core.telegram.org/mtproto/security_guidelines#mtproto-encrypted-messages
|
||||
our_key = sha256(auth_key.data[x + 88 : x + 120] + plaintext).digest()
|
||||
if msg_key != our_key[8:24]:
|
||||
raise ValueError("server msgkey mismatches with ours")
|
||||
raise MsgKeyMismatchError()
|
||||
|
||||
return plaintext
|
||||
|
||||
|
@ -62,11 +62,15 @@ from ..utils import (
|
||||
)
|
||||
from .types import (
|
||||
BadMessageError,
|
||||
DecompressionFailed,
|
||||
Deserialization,
|
||||
DeserializationFailure,
|
||||
MsgBufferTooSmall,
|
||||
MsgId,
|
||||
Mtp,
|
||||
RpcError,
|
||||
RpcResult,
|
||||
UnexpectedConstructor,
|
||||
Update,
|
||||
)
|
||||
|
||||
@ -85,6 +89,7 @@ UPDATE_IDS = {
|
||||
AffectedFoundMessages.constructor_id(),
|
||||
AffectedHistory.constructor_id(),
|
||||
AffectedMessages.constructor_id(),
|
||||
# TODO InvitedUsers
|
||||
}
|
||||
|
||||
HEADER_LEN = 8 + 8 # salt, client_id
|
||||
@ -151,7 +156,7 @@ class Encrypted(Mtp):
|
||||
self._last_msg_id: int
|
||||
self._in_pending_ack: list[int] = []
|
||||
self._msg_count: int
|
||||
self._reset_session()
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
def auth_key(self) -> bytes:
|
||||
@ -166,13 +171,6 @@ class Encrypted(Mtp):
|
||||
def _adjusted_now(self) -> float:
|
||||
return time.time() + self._time_offset
|
||||
|
||||
def _reset_session(self) -> None:
|
||||
self._client_id = struct.unpack("<q", os.urandom(8))[0]
|
||||
self._sequence = 0
|
||||
self._last_msg_id = 0
|
||||
self._in_pending_ack.clear()
|
||||
self._msg_count = 0
|
||||
|
||||
def _get_new_msg_id(self) -> int:
|
||||
new_msg_id = int(self._adjusted_now() * 0x100000000)
|
||||
if self._last_msg_id >= new_msg_id:
|
||||
@ -245,12 +243,38 @@ class Encrypted(Mtp):
|
||||
result = rpc_result.result
|
||||
|
||||
msg_id = MsgId(req_msg_id)
|
||||
inner_constructor = struct.unpack_from("<I", result)[0]
|
||||
|
||||
try:
|
||||
inner_constructor = struct.unpack_from("<I", result)[0]
|
||||
except struct.error as e:
|
||||
# If the result is empty, we can't unpack it.
|
||||
# This can happen if the server returns an empty response.
|
||||
logging.exception(e)
|
||||
self._deserialization.append(
|
||||
DeserializationFailure(
|
||||
msg_id=msg_id,
|
||||
error=MsgBufferTooSmall(),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if inner_constructor == GeneratedRpcError.constructor_id():
|
||||
error = RpcError._from_mtproto_error(GeneratedRpcError.from_bytes(result))
|
||||
error.msg_id = msg_id
|
||||
self._deserialization.append(error)
|
||||
try:
|
||||
error = RpcError._from_mtproto_error(
|
||||
GeneratedRpcError.from_bytes(result)
|
||||
)
|
||||
error.msg_id = msg_id
|
||||
self._deserialization.append(error)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
self._deserialization.append(
|
||||
DeserializationFailure(
|
||||
msg_id=msg_id,
|
||||
error=UnexpectedConstructor(
|
||||
id=inner_constructor,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif inner_constructor == RpcAnswerUnknown.constructor_id():
|
||||
pass # msg_id = rpc_drop_answer.msg_id
|
||||
elif inner_constructor == RpcAnswerDroppedRunning.constructor_id():
|
||||
@ -258,9 +282,15 @@ class Encrypted(Mtp):
|
||||
elif inner_constructor == RpcAnswerDropped.constructor_id():
|
||||
pass # dropped
|
||||
elif inner_constructor == GzipPacked.constructor_id():
|
||||
body = gzip_decompress(GzipPacked.from_bytes(result))
|
||||
self._store_own_updates(body)
|
||||
self._deserialization.append(RpcResult(msg_id, body))
|
||||
try:
|
||||
body = gzip_decompress(GzipPacked.from_bytes(result))
|
||||
self._store_own_updates(body)
|
||||
self._deserialization.append(RpcResult(msg_id, body))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
self._deserialization.append(
|
||||
DeserializationFailure(msg_id=msg_id, error=DecompressionFailed())
|
||||
)
|
||||
else:
|
||||
self._store_own_updates(result)
|
||||
self._deserialization.append(RpcResult(msg_id, result))
|
||||
@ -300,7 +330,7 @@ class Encrypted(Mtp):
|
||||
elif bad_msg.error_code in (16, 17):
|
||||
self._correct_time_offset(message.msg_id)
|
||||
elif bad_msg.error_code in (32, 33):
|
||||
self._reset_session()
|
||||
self.reset()
|
||||
else:
|
||||
raise exc
|
||||
|
||||
@ -365,6 +395,9 @@ class Encrypted(Mtp):
|
||||
for inner_message in container.messages:
|
||||
self._process_message(inner_message)
|
||||
|
||||
def _handle_msg_copy(self, message: Message) -> None:
|
||||
raise RuntimeError("msg_copy should not be used")
|
||||
|
||||
def _handle_gzip_packed(self, message: Message) -> None:
|
||||
container = GzipPacked.from_bytes(message.body)
|
||||
inner_body = gzip_decompress(container)
|
||||
@ -459,3 +492,11 @@ class Encrypted(Mtp):
|
||||
result = self._deserialization[:]
|
||||
self._deserialization.clear()
|
||||
return result
|
||||
|
||||
def reset(self) -> None:
|
||||
self._client_id = struct.unpack("<q", os.urandom(8))[0]
|
||||
self._sequence = 0
|
||||
self._last_msg_id = 0
|
||||
self._in_pending_ack.clear()
|
||||
self._msg_count = 0
|
||||
self._salt_request_msg_id = None
|
||||
|
@ -47,9 +47,10 @@ class Plain(Mtp):
|
||||
if length < 0:
|
||||
raise ValueError(f"bad length: expected >= 0, got: {length}")
|
||||
|
||||
if 20 + length > len(payload):
|
||||
raise ValueError(
|
||||
f"message too short, expected: {20 + length}, got {len(payload)}"
|
||||
)
|
||||
if 20 + length > (lp := len(payload)):
|
||||
raise ValueError(f"message too short, expected: {20 + length}, got {lp}")
|
||||
|
||||
return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))]
|
||||
|
||||
def reset(self) -> None:
|
||||
self._buffer.clear()
|
||||
|
@ -5,6 +5,7 @@ from typing import NewType, Optional
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from ...crypto.crypto import CryptoError
|
||||
from ...tl.mtproto.types import RpcError as GeneratedRpcError
|
||||
|
||||
MsgId = NewType("MsgId", int)
|
||||
@ -180,7 +181,105 @@ class BadMessageError(ValueError):
|
||||
return self._code == other._code
|
||||
|
||||
|
||||
Deserialization = Update | RpcResult | RpcError | BadMessageError
|
||||
DeserializationError = ValueError
|
||||
|
||||
|
||||
class DeserializationFailure:
|
||||
__slots__ = ("msg_id", "error")
|
||||
|
||||
def __init__(self, msg_id: MsgId, error: DeserializationError) -> None:
|
||||
self.msg_id = msg_id
|
||||
self.error = error
|
||||
|
||||
|
||||
Deserialization = (
|
||||
Update | RpcResult | RpcError | BadMessageError | DeserializationFailure
|
||||
)
|
||||
|
||||
|
||||
# Deserialization errors are not fatal, so we don't subclass RpcError.
|
||||
class BadAuthKeyError(DeserializationError):
|
||||
def __init__(self, *args: object, got: int, expected: int) -> None:
|
||||
super().__init__(f"Bad server auth key (got {got}, expected {expected})", *args)
|
||||
self._got = got
|
||||
self._expected = expected
|
||||
|
||||
@property
|
||||
def got(self):
|
||||
return self._got
|
||||
|
||||
@property
|
||||
def expected(self):
|
||||
return self._expected
|
||||
|
||||
|
||||
class BadMsgIdError(DeserializationError):
|
||||
def __init__(self, *args: object, got: int) -> None:
|
||||
super().__init__(f"Bad server message id (got {got})", *args)
|
||||
self._got = got
|
||||
|
||||
@property
|
||||
def got(self):
|
||||
return self._got
|
||||
|
||||
|
||||
class NegativeLengthError(DeserializationError):
|
||||
def __init__(self, *args: object, got: int) -> None:
|
||||
super().__init__(f"Bad server message length (got {got})", *args)
|
||||
self._got = got
|
||||
|
||||
@property
|
||||
def got(self):
|
||||
return self._got
|
||||
|
||||
|
||||
class TooLongMsgError(DeserializationError):
|
||||
__slots__ = ("expected", "got")
|
||||
|
||||
def __init__(self, *args: object, got: int, max_length: int) -> None:
|
||||
super().__init__(
|
||||
f"Bad server message length (got {got}, when at most it should be {max_length})",
|
||||
*args,
|
||||
)
|
||||
self._got = got
|
||||
self._expected = max_length
|
||||
|
||||
@property
|
||||
def got(self):
|
||||
return self._got
|
||||
|
||||
@property
|
||||
def expected(self):
|
||||
return self._expected
|
||||
|
||||
|
||||
class MsgBufferTooSmall(DeserializationError):
|
||||
def __init__(self, *args: object) -> None:
|
||||
super().__init__(
|
||||
"Server responded with a payload that's too small to fit a valid message",
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
class DecompressionFailed(DeserializationError):
|
||||
def __init__(self, *args: object) -> None:
|
||||
super().__init__("Failed to decompress server's data", *args)
|
||||
|
||||
|
||||
class UnexpectedConstructor(DeserializationError):
|
||||
def __init__(self, *args: object, id: int) -> None:
|
||||
super().__init__(f"Unexpected constructor: {id:08x}", *args)
|
||||
|
||||
|
||||
class DecryptionError(DeserializationError):
|
||||
def __init__(self, *args: object, error: CryptoError) -> None:
|
||||
super().__init__(f"failed to decrypt message: {error}", *args)
|
||||
|
||||
self._error = error
|
||||
|
||||
@property
|
||||
def error(self):
|
||||
return self._error
|
||||
|
||||
|
||||
# https://core.telegram.org/mtproto/description
|
||||
@ -209,3 +308,9 @@ class Mtp(ABC):
|
||||
"""
|
||||
Deserialize incoming buffer payload.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset the internal buffer.
|
||||
"""
|
||||
|
0
client/src/telethon/_impl/mtsender/errors.py
Normal file
0
client/src/telethon/_impl/mtsender/errors.py
Normal file
@ -24,6 +24,7 @@ from ..mtproto import (
|
||||
Update,
|
||||
authentication,
|
||||
)
|
||||
from ..mtproto.mtp.types import DeserializationFailure
|
||||
from ..tl import Request as RemoteCall
|
||||
from ..tl.abcs import Updates
|
||||
from ..tl.core import Serializable
|
||||
@ -334,8 +335,12 @@ class Sender:
|
||||
self._process_result(result)
|
||||
elif isinstance(result, RpcError):
|
||||
self._process_error(result)
|
||||
else:
|
||||
elif isinstance(result, BadMessageError):
|
||||
self._process_bad_message(result)
|
||||
elif isinstance(result, DeserializationFailure):
|
||||
self._process_deserialize_error(result)
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected result: {result}")
|
||||
|
||||
def _process_update(self, update: bytes | bytearray | memoryview) -> None:
|
||||
try:
|
||||
@ -424,6 +429,17 @@ class Sender:
|
||||
result._caused_by = struct.unpack_from("<I", req.body)[0]
|
||||
req.result.set_exception(result)
|
||||
|
||||
def _process_deserialize_error(self, failure: DeserializationFailure):
|
||||
req = self._pop_request(failure.msg_id)
|
||||
|
||||
if req:
|
||||
logging.debug(f"Got deserialization failure {failure.error}")
|
||||
req.result.set_exception(failure.error)
|
||||
else:
|
||||
logging.info(
|
||||
f"Got deserialization failure {failure.error} but no such request is saved"
|
||||
)
|
||||
|
||||
def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]:
|
||||
for req in self._requests:
|
||||
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
|
||||
|
Loading…
Reference in New Issue
Block a user