feat: implement Perfect Forward Secrecy for MTProtov2

This commit is contained in:
habcawa 2025-05-20 16:19:09 +02:00
parent c97e6023c4
commit 54cb16783f
9 changed files with 215 additions and 36 deletions

View File

@ -261,7 +261,8 @@ class TelegramBaseClient(abc.ABC):
base_logger: typing.Union[str, logging.Logger] = None, base_logger: typing.Union[str, logging.Logger] = None,
receive_updates: bool = True, receive_updates: bool = True,
catch_up: bool = False, catch_up: bool = False,
entity_cache_limit: int = 5000 entity_cache_limit: int = 5000,
store_tmp_auth_key_on_disk: bool = True
): ):
if not api_id or not api_hash: if not api_id or not api_hash:
raise ValueError( raise ValueError(
@ -287,7 +288,7 @@ class TelegramBaseClient(abc.ABC):
# Determine what session object we have # Determine what session object we have
if isinstance(session, (str, pathlib.Path)): if isinstance(session, (str, pathlib.Path)):
try: try:
session = SQLiteSession(str(session)) session = SQLiteSession(str(session), store_tmp_auth_key_on_disk=store_tmp_auth_key_on_disk)
except ImportError: except ImportError:
import warnings import warnings
warnings.warn( warnings.warn(
@ -459,6 +460,7 @@ class TelegramBaseClient(abc.ABC):
auto_reconnect=self._auto_reconnect, auto_reconnect=self._auto_reconnect,
connect_timeout=self._timeout, connect_timeout=self._timeout,
auth_key_callback=self._auth_key_callback, auth_key_callback=self._auth_key_callback,
tmp_auth_key_callback=self._tmp_auth_key_callback,
updates_queue=self._updates_queue, updates_queue=self._updates_queue,
auto_reconnect_callback=self._handle_auto_reconnect auto_reconnect_callback=self._handle_auto_reconnect
) )
@ -558,6 +560,7 @@ class TelegramBaseClient(abc.ABC):
return return
self.session.auth_key = self._sender.auth_key self.session.auth_key = self._sender.auth_key
self.session.tmp_auth_key = self._sender.tmp_auth_key
self.session.save() self.session.save()
try: try:
@ -770,6 +773,14 @@ class TelegramBaseClient(abc.ABC):
self.session.auth_key = auth_key self.session.auth_key = auth_key
self.session.save() self.session.save()
def _tmp_auth_key_callback(self: 'TelegramClient', tmp_auth_key):
"""
Callback from the sender whenever it needed to generate a
new authorization key. This means we are not authorized.
"""
self.session.tmp_auth_key = tmp_auth_key
self.session.save()
# endregion # endregion
# region Working with different connections/Data Centers # region Working with different connections/Data Centers

View File

@ -12,13 +12,16 @@ class AuthKey:
Represents an authorization key, used to encrypt and decrypt Represents an authorization key, used to encrypt and decrypt
messages sent to Telegram's data centers. messages sent to Telegram's data centers.
""" """
def __init__(self, data): def __init__(self, data, expires_at: int = -1):
""" """
Initializes a new authorization key. Initializes a new authorization key.
:param data: the data in bytes that represent this auth key. :param data: the data in bytes that represent this auth key.
:param expires_at: unix timestamp of key expiry
""" """
self.key = data self.key = data
self.expires_at = expires_at
self.tmp_key_bound = False
@property @property
def key(self): def key(self):

View File

@ -60,7 +60,8 @@ class MessagePacker:
if size <= MessageContainer.MAXIMUM_SIZE: if size <= MessageContainer.MAXIMUM_SIZE:
state.msg_id = self._state.write_data_as_message( state.msg_id = self._state.write_data_as_message(
buffer, state.data, isinstance(state.request, TLRequest), buffer, state.data, isinstance(state.request, TLRequest),
after_id=state.after.msg_id if state.after else None after_id=state.after.msg_id if state.after else None,
msg_id=state.msg_id
) )
batch.append(state) batch.append(state)
self._log.debug('Assigned msg_id = %d to %s (%x)', self._log.debug('Assigned msg_id = %d to %s (%x)',

View File

@ -7,7 +7,7 @@ import time
from hashlib import sha1 from hashlib import sha1
from ..tl.types import ( from ..tl.types import (
ResPQ, PQInnerData, ServerDHParamsFail, ServerDHParamsOk, ResPQ, PQInnerData, PQInnerDataTemp, ServerDHParamsFail, ServerDHParamsOk,
ServerDHInnerData, ClientDHInnerData, DhGenOk, DhGenRetry, DhGenFail ServerDHInnerData, ClientDHInnerData, DhGenOk, DhGenRetry, DhGenFail
) )
from .. import helpers from .. import helpers
@ -17,13 +17,15 @@ from ..extensions import BinaryReader
from ..tl.functions import ( from ..tl.functions import (
ReqPqMultiRequest, ReqDHParamsRequest, SetClientDHParamsRequest ReqPqMultiRequest, ReqDHParamsRequest, SetClientDHParamsRequest
) )
from ..tl.functions.auth import BindTempAuthKeyRequest
async def do_authentication(sender, tmp_auth_key=False, tmp_auth_key_expires_s = 24 * 3600):
async def do_authentication(sender):
""" """
Executes the authentication process with the Telegram servers. Executes the authentication process with the Telegram servers.
:param sender: a connected `MTProtoPlainSender`. :param sender: a connected `MTProtoPlainSender`.
:param tmp_auth_key: whether the flow for a tmp_auth_key should be executed
:param tmp_auth_key_expires_s: duration in s until tmp_auth_key expires
:return: returns a (authorization key, time offset) tuple. :return: returns a (authorization key, time offset) tuple.
""" """
# Step 1 sending: PQ Request, endianness doesn't matter since it's random # Step 1 sending: PQ Request, endianness doesn't matter since it's random
@ -41,6 +43,17 @@ async def do_authentication(sender):
p, q = rsa.get_byte_array(p), rsa.get_byte_array(q) p, q = rsa.get_byte_array(p), rsa.get_byte_array(q)
new_nonce = int.from_bytes(os.urandom(32), 'little', signed=True) new_nonce = int.from_bytes(os.urandom(32), 'little', signed=True)
if tmp_auth_key:
expires_at = int(time.time()) + tmp_auth_key_expires_s
pq_inner_data = bytes(PQInnerDataTemp(
pq=rsa.get_byte_array(pq), p=p, q=q,
nonce=res_pq.nonce,
server_nonce=res_pq.server_nonce,
new_nonce=new_nonce,
expires_in=tmp_auth_key_expires_s
))
else:
pq_inner_data = bytes(PQInnerData( pq_inner_data = bytes(PQInnerData(
pq=rsa.get_byte_array(pq), p=p, q=q, pq=rsa.get_byte_array(pq), p=p, q=q,
nonce=res_pq.nonce, nonce=res_pq.nonce,
@ -197,8 +210,33 @@ async def do_authentication(sender):
if not isinstance(dh_gen, DhGenOk): if not isinstance(dh_gen, DhGenOk):
raise AssertionError('Step 3.2 answer was %s' % dh_gen) raise AssertionError('Step 3.2 answer was %s' % dh_gen)
if tmp_auth_key:
# auth_key is only a tmp_auth_key here!
return auth_key, expires_at
else:
return auth_key, time_offset return auth_key, time_offset
async def bind_tmp_auth_key(sender, auth_key, tmp_auth_key):
"""
Binds a tmpAuthKey to an authkey.
:param sender: a connected `MTProtoender`.
:param auth_key: auth key to bind to
:parm tmp_auth_key: unbound tmp_auth_key
:return: None
"""
nonce = int.from_bytes(os.urandom(8), 'little', signed=True)
timestamp = tmp_auth_key.expires_at
encrypted_bind, msg_id = sender._state.get_encrypted_bind(nonce, auth_key, tmp_auth_key, timestamp)
bind_request = BindTempAuthKeyRequest(
perm_auth_key_id=auth_key.key_id,
nonce=nonce,
expires_at=timestamp,
encrypted_message=encrypted_bind
)
await sender.send(bind_request, msg_id=msg_id)
def get_int(byte_array, signed=True): def get_int(byte_array, signed=True):
""" """

View File

@ -47,7 +47,7 @@ class MTProtoSender:
""" """
def __init__(self, auth_key, *, loggers, def __init__(self, auth_key, *, loggers,
retries=5, delay=1, auto_reconnect=True, connect_timeout=None, retries=5, delay=1, auto_reconnect=True, connect_timeout=None,
auth_key_callback=None, auth_key_callback=None, tmp_auth_key=None, tmp_auth_key_callback=None,
updates_queue=None, auto_reconnect_callback=None): updates_queue=None, auto_reconnect_callback=None):
self._connection = None self._connection = None
self._loggers = loggers self._loggers = loggers
@ -57,6 +57,7 @@ class MTProtoSender:
self._auto_reconnect = auto_reconnect self._auto_reconnect = auto_reconnect
self._connect_timeout = connect_timeout self._connect_timeout = connect_timeout
self._auth_key_callback = auth_key_callback self._auth_key_callback = auth_key_callback
self._tmp_auth_key_callback = tmp_auth_key_callback
self._updates_queue = updates_queue self._updates_queue = updates_queue
self._auto_reconnect_callback = auto_reconnect_callback self._auto_reconnect_callback = auto_reconnect_callback
self._connect_lock = asyncio.Lock() self._connect_lock = asyncio.Lock()
@ -79,7 +80,8 @@ class MTProtoSender:
# Preserving the references of the AuthKey and state is important # Preserving the references of the AuthKey and state is important
self.auth_key = auth_key or AuthKey(None) self.auth_key = auth_key or AuthKey(None)
self._state = MTProtoState(self.auth_key, loggers=self._loggers) self.tmp_auth_key = tmp_auth_key or AuthKey(None)
self._state = MTProtoState(self.tmp_auth_key, loggers=self._loggers)
# Outgoing messages are put in a queue and sent in a batch. # Outgoing messages are put in a queue and sent in a batch.
# Note that here we're also storing their ``_RequestState``. # Note that here we're also storing their ``_RequestState``.
@ -152,7 +154,7 @@ class MTProtoSender:
""" """
await self._disconnect() await self._disconnect()
def send(self, request, ordered=False): def send(self, request, ordered=False, msg_id=None):
""" """
This method enqueues the given request to be sent. Its send This method enqueues the given request to be sent. Its send
state will be saved until a response arrives, and a ``Future`` state will be saved until a response arrives, and a ``Future``
@ -180,7 +182,7 @@ class MTProtoSender:
if not utils.is_list_like(request): if not utils.is_list_like(request):
try: try:
state = RequestState(request) state = RequestState(request, msg_id=msg_id)
except struct.error as e: except struct.error as e:
# "struct.error: required argument is not an integer" is not # "struct.error: required argument is not an integer" is not
# very helpful; log the request to find out what wasn't int. # very helpful; log the request to find out what wasn't int.
@ -254,6 +256,25 @@ class MTProtoSender:
await asyncio.sleep(self._delay) await asyncio.sleep(self._delay)
continue # next iteration we will try to reconnect continue # next iteration we will try to reconnect
if not self.tmp_auth_key:
# establish tmp_auth_key here, but make sure to bind to the auth_key later
try:
if not await self._try_gen_tmp_auth_key(attempt):
continue # keep retrying until we have the tmp auth key
except (IOError, asyncio.TimeoutError) as e:
# Sometimes, specially during user-DC migrations,
# Telegram may close the connection during auth_key
# generation. If that's the case, we will need to
# connect again.
self._log.warning('Connection error %d during tmp_auth_key gen: %s: %s',
attempt, type(e).__name__, e)
# Whatever the IOError was, make sure to disconnect so we can
# reconnect cleanly after.
await self._connection.disconnect()
connected = False
await asyncio.sleep(self._delay)
continue # next iteration we will try to reconnect
break # all steps done, break retry loop break # all steps done, break retry loop
else: else:
if not connected: if not connected:
@ -264,15 +285,21 @@ class MTProtoSender:
raise e raise e
loop = helpers.get_running_loop() loop = helpers.get_running_loop()
# dirty hack, but otherwise we cannot send the binding of the tmp_auth_key
self._user_connected = True
# update key, this was unavailable at init of self._state
self._state.auth_key = self.tmp_auth_key
self._log.debug('Starting send loop') self._log.debug('Starting send loop')
self._send_loop_handle = loop.create_task(self._send_loop()) self._send_loop_handle = loop.create_task(self._send_loop())
self._log.debug('Starting receive loop') self._log.debug('Starting receive loop')
self._recv_loop_handle = loop.create_task(self._recv_loop()) self._recv_loop_handle = loop.create_task(self._recv_loop())
# _disconnected only completes after manual disconnection # both self.auth_key and self.tmp_auth_key are required for the binding
# or errors after which the sender cannot continue such # and it can only take place after the send/recv loops as the
# as failing to reconnect or any unexpected error. # the binding message needs to be sent encrypted
await authenticator.bind_tmp_auth_key(self, self.auth_key, self.tmp_auth_key)
if self._disconnected.done(): if self._disconnected.done():
self._disconnected = loop.create_future() self._disconnected = loop.create_future()
@ -290,6 +317,27 @@ class MTProtoSender:
await asyncio.sleep(self._delay) await asyncio.sleep(self._delay)
return False return False
async def _try_gen_tmp_auth_key(self, attempt):
plain = MTProtoPlainSender(self._connection, loggers=self._loggers)
try:
self._log.debug('New tmp_auth_key attempt %d...', attempt)
self.tmp_auth_key.key, self.tmp_auth_key.expires_at = \
await authenticator.do_authentication(plain, tmp_auth_key=True)
# This is *EXTREMELY* important since we don't control
# external references to the temporary authorization key, we must
# notify whenever we change it. This is crucial when we
# switch to different data centers.
if self._tmp_auth_key_callback:
self._tmp_auth_key_callback(self.tmp_auth_key)
self._log.info('tmp_auth_key generation success!')
return True
except (SecurityError, AssertionError) as e:
self._log.warning('Attempt %d at new tmp_auth_key failed: %s', attempt, e)
await asyncio.sleep(self._delay)
return False
async def _try_gen_auth_key(self, attempt): async def _try_gen_auth_key(self, attempt):
plain = MTProtoPlainSender(self._connection, loggers=self._loggers) plain = MTProtoPlainSender(self._connection, loggers=self._loggers)
try: try:
@ -366,7 +414,10 @@ class MTProtoSender:
self._reconnecting = False self._reconnecting = False
# Start with a clean state (and thus session ID) to avoid old msgs # Start with a clean state (and thus session ID) to avoid old msgs
self._state.reset() self._state.reset(keep_key=False)
self.tmp_auth_key = AuthKey(None)
if self._tmp_auth_key_callback:
self._tmp_auth_key_callback(self.tmp_auth_key)
retries = self._retries if self._auto_reconnect else 0 retries = self._retries if self._auto_reconnect else 0

View File

@ -1,7 +1,7 @@
import os import os
import struct import struct
import time import time
from hashlib import sha256 from hashlib import sha1, sha256
from collections import deque from collections import deque
from ..crypto import AES from ..crypto import AES
@ -11,8 +11,7 @@ from ..tl.core import TLMessage
from ..tl.tlobject import TLRequest from ..tl.tlobject import TLRequest
from ..tl.functions import InvokeAfterMsgRequest from ..tl.functions import InvokeAfterMsgRequest
from ..tl.core.gzippacked import GzipPacked from ..tl.core.gzippacked import GzipPacked
from ..tl.types import BadServerSalt, BadMsgNotification from ..tl.types import BadServerSalt, BadMsgNotification, BindAuthKeyInner
# N is not specified in https://core.telegram.org/mtproto/security_guidelines#checking-msg-id, but 500 is reasonable # N is not specified in https://core.telegram.org/mtproto/security_guidelines#checking-msg-id, but 500 is reasonable
MAX_RECENT_MSG_IDS = 500 MAX_RECENT_MSG_IDS = 500
@ -106,13 +105,37 @@ class MTProtoState:
return aes_key, aes_iv return aes_key, aes_iv
@staticmethod
def _calc_key_v1(auth_key, msg_key, client):
"""
Calculate the key based on Telegram guidelines for MTProto 1,
specifying whether it's the client or not. See
https://core.telegram.org/mtproto/description_v1#defining-aes-key-and-initialization-vector
"""
x = 0 if client else 8
sha1a = sha1(msg_key + auth_key[x:x+32]).digest()
sha1b = sha1(auth_key[x+32:x+48] + msg_key + auth_key[x+48:x+64]).digest()
sha1c = sha1(auth_key[x+64:x+96] + msg_key).digest()
sha1d = sha1(msg_key + auth_key[x+96:x+128]).digest()
aes_key = sha1a[0:8] + sha1b[8:20] + sha1c[4:16]
aes_iv = sha1a[8:20] + sha1b[0:8] + sha1c[16:20] + sha1d[0:8]
return aes_key, aes_iv
def write_data_as_message(self, buffer, data, content_related, def write_data_as_message(self, buffer, data, content_related,
*, after_id=None): *, after_id=None, msg_id=None):
""" """
Writes a message containing the given data into buffer. Writes a message containing the given data into buffer.
Returns the message id. Returns the message id.
""" """
if msg_id is None:
# this should be the default - the binding of a tmpAuthKey is the only
# exception, as the msg_id of the encrypted BindAuthKeyInner needs to be equal to the
# msg_id of the outer message
# see: https://core.telegram.org/method/auth.bindTempAuthKey#encrypting-the-binding-message
msg_id = self._get_new_msg_id() msg_id = self._get_new_msg_id()
seq_no = self._get_seq_no(content_related) seq_no = self._get_seq_no(content_related)
if after_id is None: if after_id is None:
@ -127,6 +150,32 @@ class MTProtoState:
buffer.write(body) buffer.write(body)
return msg_id return msg_id
def get_encrypted_bind(self, nonce, auth_key, tmp_auth_key, timestamp):
# strangely, this should be encrypted using MTProto1
# see https://core.telegram.org/method/auth.bindTempAuthKey#encrypting-the-binding-message
msg_id = self._get_new_msg_id()
seq_no = 0
bind = BindAuthKeyInner(
nonce=nonce,
temp_auth_key_id=tmp_auth_key.key_id,
perm_auth_key_id=auth_key.key_id,
temp_session_id=self.id,
expires_at=timestamp
)
bind = bytes(bind)
assert len(bind) == 40
payload = os.urandom(int(128/8)) + struct.pack('<qii', msg_id, seq_no, len(bind)) + bind
padding = os.urandom(len(payload) % 16)
msg_key = sha1(payload).digest()[4:20]
aes_key, aes_iv = self._calc_key_v1(auth_key.key, msg_key, True)
key_id = struct.pack('<q', auth_key.key_id)
crypt = AES.encrypt_ige(payload + padding, aes_key, aes_iv)
encrypted_message = (key_id + msg_key + crypt)
return encrypted_message, msg_id
def encrypt_message_data(self, data): def encrypt_message_data(self, data):
""" """
Encrypts the given message data using the current authorization key Encrypts the given message data using the current authorization key

View File

@ -10,9 +10,9 @@ class RequestState:
""" """
__slots__ = ('container_id', 'msg_id', 'request', 'data', 'future', 'after') __slots__ = ('container_id', 'msg_id', 'request', 'data', 'future', 'after')
def __init__(self, request, after=None): def __init__(self, request, after=None, msg_id=None):
self.container_id = None self.container_id = None
self.msg_id = None self.msg_id = msg_id
self.request = request self.request = request
self.data = bytes(request) self.data = bytes(request)
self.future = asyncio.Future() self.future = asyncio.Future()

View File

@ -32,6 +32,7 @@ class MemorySession(Session):
self._server_address = None self._server_address = None
self._port = None self._port = None
self._auth_key = None self._auth_key = None
self._tmp_auth_key = None
self._takeout_id = None self._takeout_id = None
self._files = {} self._files = {}
@ -59,10 +60,18 @@ class MemorySession(Session):
def auth_key(self): def auth_key(self):
return self._auth_key return self._auth_key
@property
def tmp_auth_key(self):
return self._tmp_auth_key
@auth_key.setter @auth_key.setter
def auth_key(self, value): def auth_key(self, value):
self._auth_key = value self._auth_key = value
@auth_key.setter
def tmp_auth_key(self, value):
self._tmp_auth_key = value
@property @property
def takeout_id(self): def takeout_id(self):
return self._takeout_id return self._takeout_id

View File

@ -18,7 +18,7 @@ except ImportError as e:
sqlite3_err = type(e) sqlite3_err = type(e)
EXTENSION = '.session' EXTENSION = '.session'
CURRENT_VERSION = 7 # database version CURRENT_VERSION = 8 # database version
class SQLiteSession(MemorySession): class SQLiteSession(MemorySession):
@ -30,13 +30,14 @@ class SQLiteSession(MemorySession):
through an official Telegram client to revoke the authorization. through an official Telegram client to revoke the authorization.
""" """
def __init__(self, session_id=None): def __init__(self, session_id=None, store_tmp_auth_key_on_disk:bool=False):
if sqlite3 is None: if sqlite3 is None:
raise sqlite3_err raise sqlite3_err
super().__init__() super().__init__()
self.filename = ':memory:' self.filename = ':memory:'
self.save_entities = True self.save_entities = True
self.store_tmp_auth_key_on_disk = store_tmp_auth_key_on_disk
if session_id: if session_id:
self.filename = session_id self.filename = session_id
@ -61,9 +62,10 @@ class SQLiteSession(MemorySession):
c.execute('select * from sessions') c.execute('select * from sessions')
tuple_ = c.fetchone() tuple_ = c.fetchone()
if tuple_: if tuple_:
self._dc_id, self._server_address, self._port, key, \ self._dc_id, self._server_address, self._port, key, tmp_key, \
self._takeout_id = tuple_ self._takeout_id = tuple_
self._auth_key = AuthKey(data=key) self._auth_key = AuthKey(data=key)
self._tmp_auth_key = AuthKey(data=tmp_key)
c.close() c.close()
else: else:
@ -77,7 +79,8 @@ class SQLiteSession(MemorySession):
server_address text, server_address text,
port integer, port integer,
auth_key blob, auth_key blob,
takeout_id integer takeout_id integer,
tmp_auth_key blob
)""" )"""
, ,
"""entities ( """entities (
@ -153,6 +156,9 @@ class SQLiteSession(MemorySession):
if old == 6: if old == 6:
old += 1 old += 1
c.execute("alter table entities add column date integer") c.execute("alter table entities add column date integer")
if old == 7:
old += 1
c.execute("alter table sessions add column tmp_auth_key blob")
c.close() c.close()
@ -168,17 +174,27 @@ class SQLiteSession(MemorySession):
self._update_session_table() self._update_session_table()
# Fetch the auth_key corresponding to this data center # Fetch the auth_key corresponding to this data center
row = self._execute('select auth_key from sessions') row = self._execute('select auth_key, tmp_auth_key from sessions')
if row and row[0]: if row and row[0]:
self._auth_key = AuthKey(data=row[0]) self._auth_key = AuthKey(data=row[0])
else: else:
self._auth_key = None self._auth_key = None
if row and row[1]:
self._tmp_auth_key = AuthKey(data=row[1])
else:
self._tmp_auth_key = None
@MemorySession.auth_key.setter @MemorySession.auth_key.setter
def auth_key(self, value): def auth_key(self, value):
self._auth_key = value self._auth_key = value
self._update_session_table() self._update_session_table()
@MemorySession.tmp_auth_key.setter
def tmp_auth_key(self, value):
self._tmp_auth_key = value
self._update_session_table()
@MemorySession.takeout_id.setter @MemorySession.takeout_id.setter
def takeout_id(self, value): def takeout_id(self, value):
self._takeout_id = value self._takeout_id = value
@ -192,12 +208,13 @@ class SQLiteSession(MemorySession):
# some more work before being able to save auth_key's for # some more work before being able to save auth_key's for
# multiple DCs. Probably done differently. # multiple DCs. Probably done differently.
c.execute('delete from sessions') c.execute('delete from sessions')
c.execute('insert or replace into sessions values (?,?,?,?,?)', ( c.execute('insert or replace into sessions values (?,?,?,?,?,?)', (
self._dc_id, self._dc_id,
self._server_address, self._server_address,
self._port, self._port,
self._auth_key.key if self._auth_key else b'', self._auth_key.key if self._auth_key else b'',
self._takeout_id self._takeout_id,
self._tmp_auth_key.key if (self.store_tmp_auth_key_on_disk and self._tmp_auth_key) else b''
)) ))
c.close() c.close()