Protect against potential replay attacks

See #3753.
This commit is contained in:
Lonami Exo 2022-05-18 12:24:28 +02:00
parent 09b9cd8193
commit 184984ac51
2 changed files with 63 additions and 0 deletions

View File

@ -505,6 +505,8 @@ class MTProtoSender:
try: try:
message = self._state.decrypt_message_data(body) message = self._state.decrypt_message_data(body)
if message is None:
continue # this message is to be ignored
except TypeNotFoundError as e: except TypeNotFoundError as e:
# Received object which we don't know how to deserialize # Received object which we don't know how to deserialize
self._log.info('Type %08x not found, remaining data %r', self._log.info('Type %08x not found, remaining data %r',

View File

@ -2,6 +2,7 @@ import os
import struct import struct
import time import time
from hashlib import sha256 from hashlib import sha256
from collections import deque
from ..crypto import AES from ..crypto import AES
from ..errors import SecurityError, InvalidBufferError from ..errors import SecurityError, InvalidBufferError
@ -10,6 +11,17 @@ from ..tl.core import TLMessage
from ..tl.tlobject import TLRequest from ..tl.tlobject import TLRequest
from ..tl.functions import InvokeAfterMsgRequest from ..tl.functions import InvokeAfterMsgRequest
from ..tl.core.gzippacked import GzipPacked from ..tl.core.gzippacked import GzipPacked
from ..tl.types import BadServerSalt, BadMsgNotification
# N is not specified in https://core.telegram.org/mtproto/security_guidelines#checking-msg-id, but 500 is reasonable
MAX_RECENT_MSG_IDS = 500
MSG_TOO_NEW_DELTA = 30
MSG_TOO_OLD_DELTA = 300
# Something must be wrong if we ignore too many messages at the same time
MAX_CONSECUTIVE_IGNORED = 10
class _OpaqueRequest(TLRequest): class _OpaqueRequest(TLRequest):
@ -54,6 +66,9 @@ class MTProtoState:
self.salt = 0 self.salt = 0
self.id = self._sequence = self._last_msg_id = None self.id = self._sequence = self._last_msg_id = None
self._recent_remote_ids = deque(maxlen=MAX_RECENT_MSG_IDS)
self._highest_remote_id = 0
self._ignore_count = 0
self.reset() self.reset()
def reset(self): def reset(self):
@ -64,6 +79,9 @@ class MTProtoState:
self.id = struct.unpack('q', os.urandom(8))[0] self.id = struct.unpack('q', os.urandom(8))[0]
self._sequence = 0 self._sequence = 0
self._last_msg_id = 0 self._last_msg_id = 0
self._recent_remote_ids.clear()
self._highest_remote_id = 0
self._ignore_count = 0
def update_message_id(self, message): def update_message_id(self, message):
""" """
@ -134,6 +152,8 @@ class MTProtoState:
""" """
Inverse of `encrypt_message_data` for incoming server messages. Inverse of `encrypt_message_data` for incoming server messages.
""" """
now = time.time() + self.time_offset # get the time as early as possible, even if other checks make it go unused
if len(body) < 8: if len(body) < 8:
raise InvalidBufferError(body) raise InvalidBufferError(body)
@ -159,6 +179,16 @@ class MTProtoState:
raise SecurityError('Server replied with a wrong session ID') raise SecurityError('Server replied with a wrong session ID')
remote_msg_id = reader.read_long() remote_msg_id = reader.read_long()
if remote_msg_id % 2 != 1:
raise SecurityError('Server sent an even msg_id')
# Only perform the (somewhat expensive) check of duplicate if we did receive a lower ID
if remote_msg_id <= self._highest_remote_id and remote_msg_id in self._recent_remote_ids:
self._log.warning('Server resent the older message %d, ignoring', remote_msg_id)
self._count_ignored()
return None
remote_sequence = reader.read_int() remote_sequence = reader.read_int()
reader.read_int() # msg_len for the inner object, padding ignored reader.read_int() # msg_len for the inner object, padding ignored
@ -167,8 +197,39 @@ class MTProtoState:
# reader isn't used for anything else after this, it's unnecessary. # reader isn't used for anything else after this, it's unnecessary.
obj = reader.tgread_object() obj = reader.tgread_object()
# "Certain client-to-server service messages containing data sent by the client to the
# server (for example, msg_id of a recent client query) may, nonetheless, be processed
# on the client even if the time appears to be "incorrect". This is especially true of
# messages to change server_salt and notifications about invalid time on the client."
#
# This means we skip the time check for certain types of messages.
if obj.CONSTRUCTOR_ID not in (BadServerSalt.CONSTRUCTOR_ID, BadMsgNotification.CONSTRUCTOR_ID):
remote_msg_time = remote_msg_id >> 32
time_delta = now - remote_msg_time
if time_delta > MSG_TOO_OLD_DELTA:
self._log.warning('Server sent a very old message with ID %d, ignoring', remote_msg_id)
self._count_ignored()
return None
if -time_delta > MSG_TOO_NEW_DELTA:
self._log.warning('Server sent a very new message with ID %d, ignoring', remote_msg_id)
self._count_ignored()
return None
self._recent_remote_ids.append(remote_msg_id)
self._highest_remote_id = remote_msg_id
self._ignore_count = 0
return TLMessage(remote_msg_id, remote_sequence, obj) return TLMessage(remote_msg_id, remote_sequence, obj)
def _count_ignored(self):
# It's possible that ignoring a message "bricks" the connection,
# but this should not happen unless there's something else wrong.
self._ignore_count += 1
if self._ignore_count >= MAX_CONSECUTIVE_IGNORED:
raise SecurityError('Too many messages had to be ignored consecutively')
def _get_new_msg_id(self): def _get_new_msg_id(self):
""" """
Generates a new unique message ID based on the current Generates a new unique message ID based on the current