feat: implement Perfect Forward Secrecy for MTProtov2

This commit is contained in:
habcawa 2025-05-20 16:19:09 +02:00
parent c97e6023c4
commit 54cb16783f
9 changed files with 215 additions and 36 deletions

View File

@ -261,7 +261,8 @@ class TelegramBaseClient(abc.ABC):
base_logger: typing.Union[str, logging.Logger] = None,
receive_updates: bool = True,
catch_up: bool = False,
entity_cache_limit: int = 5000
entity_cache_limit: int = 5000,
store_tmp_auth_key_on_disk: bool = True
):
if not api_id or not api_hash:
raise ValueError(
@ -287,7 +288,7 @@ class TelegramBaseClient(abc.ABC):
# Determine what session object we have
if isinstance(session, (str, pathlib.Path)):
try:
session = SQLiteSession(str(session))
session = SQLiteSession(str(session), store_tmp_auth_key_on_disk=store_tmp_auth_key_on_disk)
except ImportError:
import warnings
warnings.warn(
@ -459,6 +460,7 @@ class TelegramBaseClient(abc.ABC):
auto_reconnect=self._auto_reconnect,
connect_timeout=self._timeout,
auth_key_callback=self._auth_key_callback,
tmp_auth_key_callback=self._tmp_auth_key_callback,
updates_queue=self._updates_queue,
auto_reconnect_callback=self._handle_auto_reconnect
)
@ -558,6 +560,7 @@ class TelegramBaseClient(abc.ABC):
return
self.session.auth_key = self._sender.auth_key
self.session.tmp_auth_key = self._sender.tmp_auth_key
self.session.save()
try:
@ -770,6 +773,14 @@ class TelegramBaseClient(abc.ABC):
self.session.auth_key = auth_key
self.session.save()
def _tmp_auth_key_callback(self: 'TelegramClient', tmp_auth_key):
"""
Callback from the sender whenever it needed to generate a
new authorization key. This means we are not authorized.
"""
self.session.tmp_auth_key = tmp_auth_key
self.session.save()
# endregion
# region Working with different connections/Data Centers

View File

@ -12,13 +12,16 @@ class AuthKey:
Represents an authorization key, used to encrypt and decrypt
messages sent to Telegram's data centers.
"""
def __init__(self, data):
def __init__(self, data, expires_at: int = -1):
"""
Initializes a new authorization key.
:param data: the data in bytes that represent this auth key.
:param expires_at: unix timestamp of key expiry
"""
self.key = data
self.expires_at = expires_at
self.tmp_key_bound = False
@property
def key(self):

View File

@ -60,7 +60,8 @@ class MessagePacker:
if size <= MessageContainer.MAXIMUM_SIZE:
state.msg_id = self._state.write_data_as_message(
buffer, state.data, isinstance(state.request, TLRequest),
after_id=state.after.msg_id if state.after else None
after_id=state.after.msg_id if state.after else None,
msg_id=state.msg_id
)
batch.append(state)
self._log.debug('Assigned msg_id = %d to %s (%x)',

View File

@ -7,7 +7,7 @@ import time
from hashlib import sha1
from ..tl.types import (
ResPQ, PQInnerData, ServerDHParamsFail, ServerDHParamsOk,
ResPQ, PQInnerData, PQInnerDataTemp, ServerDHParamsFail, ServerDHParamsOk,
ServerDHInnerData, ClientDHInnerData, DhGenOk, DhGenRetry, DhGenFail
)
from .. import helpers
@ -17,13 +17,15 @@ from ..extensions import BinaryReader
from ..tl.functions import (
ReqPqMultiRequest, ReqDHParamsRequest, SetClientDHParamsRequest
)
from ..tl.functions.auth import BindTempAuthKeyRequest
async def do_authentication(sender):
async def do_authentication(sender, tmp_auth_key=False, tmp_auth_key_expires_s = 24 * 3600):
"""
Executes the authentication process with the Telegram servers.
:param sender: a connected `MTProtoPlainSender`.
:param tmp_auth_key: whether the flow for a tmp_auth_key should be executed
:param tmp_auth_key_expires_s: duration in s until tmp_auth_key expires
:return: returns a (authorization key, time offset) tuple.
"""
# Step 1 sending: PQ Request, endianness doesn't matter since it's random
@ -41,6 +43,17 @@ async def do_authentication(sender):
p, q = rsa.get_byte_array(p), rsa.get_byte_array(q)
new_nonce = int.from_bytes(os.urandom(32), 'little', signed=True)
if tmp_auth_key:
expires_at = int(time.time()) + tmp_auth_key_expires_s
pq_inner_data = bytes(PQInnerDataTemp(
pq=rsa.get_byte_array(pq), p=p, q=q,
nonce=res_pq.nonce,
server_nonce=res_pq.server_nonce,
new_nonce=new_nonce,
expires_in=tmp_auth_key_expires_s
))
else:
pq_inner_data = bytes(PQInnerData(
pq=rsa.get_byte_array(pq), p=p, q=q,
nonce=res_pq.nonce,
@ -197,8 +210,33 @@ async def do_authentication(sender):
if not isinstance(dh_gen, DhGenOk):
raise AssertionError('Step 3.2 answer was %s' % dh_gen)
if tmp_auth_key:
# auth_key is only a tmp_auth_key here!
return auth_key, expires_at
else:
return auth_key, time_offset
async def bind_tmp_auth_key(sender, auth_key, tmp_auth_key):
"""
Binds a tmpAuthKey to an authkey.
:param sender: a connected `MTProtoender`.
:param auth_key: auth key to bind to
:parm tmp_auth_key: unbound tmp_auth_key
:return: None
"""
nonce = int.from_bytes(os.urandom(8), 'little', signed=True)
timestamp = tmp_auth_key.expires_at
encrypted_bind, msg_id = sender._state.get_encrypted_bind(nonce, auth_key, tmp_auth_key, timestamp)
bind_request = BindTempAuthKeyRequest(
perm_auth_key_id=auth_key.key_id,
nonce=nonce,
expires_at=timestamp,
encrypted_message=encrypted_bind
)
await sender.send(bind_request, msg_id=msg_id)
def get_int(byte_array, signed=True):
"""

View File

@ -47,7 +47,7 @@ class MTProtoSender:
"""
def __init__(self, auth_key, *, loggers,
retries=5, delay=1, auto_reconnect=True, connect_timeout=None,
auth_key_callback=None,
auth_key_callback=None, tmp_auth_key=None, tmp_auth_key_callback=None,
updates_queue=None, auto_reconnect_callback=None):
self._connection = None
self._loggers = loggers
@ -57,6 +57,7 @@ class MTProtoSender:
self._auto_reconnect = auto_reconnect
self._connect_timeout = connect_timeout
self._auth_key_callback = auth_key_callback
self._tmp_auth_key_callback = tmp_auth_key_callback
self._updates_queue = updates_queue
self._auto_reconnect_callback = auto_reconnect_callback
self._connect_lock = asyncio.Lock()
@ -79,7 +80,8 @@ class MTProtoSender:
# Preserving the references of the AuthKey and state is important
self.auth_key = auth_key or AuthKey(None)
self._state = MTProtoState(self.auth_key, loggers=self._loggers)
self.tmp_auth_key = tmp_auth_key or AuthKey(None)
self._state = MTProtoState(self.tmp_auth_key, loggers=self._loggers)
# Outgoing messages are put in a queue and sent in a batch.
# Note that here we're also storing their ``_RequestState``.
@ -152,7 +154,7 @@ class MTProtoSender:
"""
await self._disconnect()
def send(self, request, ordered=False):
def send(self, request, ordered=False, msg_id=None):
"""
This method enqueues the given request to be sent. Its send
state will be saved until a response arrives, and a ``Future``
@ -180,7 +182,7 @@ class MTProtoSender:
if not utils.is_list_like(request):
try:
state = RequestState(request)
state = RequestState(request, msg_id=msg_id)
except struct.error as e:
# "struct.error: required argument is not an integer" is not
# very helpful; log the request to find out what wasn't int.
@ -254,6 +256,25 @@ class MTProtoSender:
await asyncio.sleep(self._delay)
continue # next iteration we will try to reconnect
if not self.tmp_auth_key:
# establish tmp_auth_key here, but make sure to bind to the auth_key later
try:
if not await self._try_gen_tmp_auth_key(attempt):
continue # keep retrying until we have the tmp auth key
except (IOError, asyncio.TimeoutError) as e:
# Sometimes, specially during user-DC migrations,
# Telegram may close the connection during auth_key
# generation. If that's the case, we will need to
# connect again.
self._log.warning('Connection error %d during tmp_auth_key gen: %s: %s',
attempt, type(e).__name__, e)
# Whatever the IOError was, make sure to disconnect so we can
# reconnect cleanly after.
await self._connection.disconnect()
connected = False
await asyncio.sleep(self._delay)
continue # next iteration we will try to reconnect
break # all steps done, break retry loop
else:
if not connected:
@ -264,15 +285,21 @@ class MTProtoSender:
raise e
loop = helpers.get_running_loop()
# dirty hack, but otherwise we cannot send the binding of the tmp_auth_key
self._user_connected = True
# update key, this was unavailable at init of self._state
self._state.auth_key = self.tmp_auth_key
self._log.debug('Starting send loop')
self._send_loop_handle = loop.create_task(self._send_loop())
self._log.debug('Starting receive loop')
self._recv_loop_handle = loop.create_task(self._recv_loop())
# _disconnected only completes after manual disconnection
# or errors after which the sender cannot continue such
# as failing to reconnect or any unexpected error.
# both self.auth_key and self.tmp_auth_key are required for the binding
# and it can only take place after the send/recv loops as the
# the binding message needs to be sent encrypted
await authenticator.bind_tmp_auth_key(self, self.auth_key, self.tmp_auth_key)
if self._disconnected.done():
self._disconnected = loop.create_future()
@ -290,6 +317,27 @@ class MTProtoSender:
await asyncio.sleep(self._delay)
return False
async def _try_gen_tmp_auth_key(self, attempt):
plain = MTProtoPlainSender(self._connection, loggers=self._loggers)
try:
self._log.debug('New tmp_auth_key attempt %d...', attempt)
self.tmp_auth_key.key, self.tmp_auth_key.expires_at = \
await authenticator.do_authentication(plain, tmp_auth_key=True)
# This is *EXTREMELY* important since we don't control
# external references to the temporary authorization key, we must
# notify whenever we change it. This is crucial when we
# switch to different data centers.
if self._tmp_auth_key_callback:
self._tmp_auth_key_callback(self.tmp_auth_key)
self._log.info('tmp_auth_key generation success!')
return True
except (SecurityError, AssertionError) as e:
self._log.warning('Attempt %d at new tmp_auth_key failed: %s', attempt, e)
await asyncio.sleep(self._delay)
return False
async def _try_gen_auth_key(self, attempt):
plain = MTProtoPlainSender(self._connection, loggers=self._loggers)
try:
@ -366,7 +414,10 @@ class MTProtoSender:
self._reconnecting = False
# Start with a clean state (and thus session ID) to avoid old msgs
self._state.reset()
self._state.reset(keep_key=False)
self.tmp_auth_key = AuthKey(None)
if self._tmp_auth_key_callback:
self._tmp_auth_key_callback(self.tmp_auth_key)
retries = self._retries if self._auto_reconnect else 0

View File

@ -1,7 +1,7 @@
import os
import struct
import time
from hashlib import sha256
from hashlib import sha1, sha256
from collections import deque
from ..crypto import AES
@ -11,8 +11,7 @@ from ..tl.core import TLMessage
from ..tl.tlobject import TLRequest
from ..tl.functions import InvokeAfterMsgRequest
from ..tl.core.gzippacked import GzipPacked
from ..tl.types import BadServerSalt, BadMsgNotification
from ..tl.types import BadServerSalt, BadMsgNotification, BindAuthKeyInner
# N is not specified in https://core.telegram.org/mtproto/security_guidelines#checking-msg-id, but 500 is reasonable
MAX_RECENT_MSG_IDS = 500
@ -106,13 +105,37 @@ class MTProtoState:
return aes_key, aes_iv
@staticmethod
def _calc_key_v1(auth_key, msg_key, client):
"""
Calculate the key based on Telegram guidelines for MTProto 1,
specifying whether it's the client or not. See
https://core.telegram.org/mtproto/description_v1#defining-aes-key-and-initialization-vector
"""
x = 0 if client else 8
sha1a = sha1(msg_key + auth_key[x:x+32]).digest()
sha1b = sha1(auth_key[x+32:x+48] + msg_key + auth_key[x+48:x+64]).digest()
sha1c = sha1(auth_key[x+64:x+96] + msg_key).digest()
sha1d = sha1(msg_key + auth_key[x+96:x+128]).digest()
aes_key = sha1a[0:8] + sha1b[8:20] + sha1c[4:16]
aes_iv = sha1a[8:20] + sha1b[0:8] + sha1c[16:20] + sha1d[0:8]
return aes_key, aes_iv
def write_data_as_message(self, buffer, data, content_related,
*, after_id=None):
*, after_id=None, msg_id=None):
"""
Writes a message containing the given data into buffer.
Returns the message id.
"""
if msg_id is None:
# this should be the default - the binding of a tmpAuthKey is the only
# exception, as the msg_id of the encrypted BindAuthKeyInner needs to be equal to the
# msg_id of the outer message
# see: https://core.telegram.org/method/auth.bindTempAuthKey#encrypting-the-binding-message
msg_id = self._get_new_msg_id()
seq_no = self._get_seq_no(content_related)
if after_id is None:
@ -127,6 +150,32 @@ class MTProtoState:
buffer.write(body)
return msg_id
def get_encrypted_bind(self, nonce, auth_key, tmp_auth_key, timestamp):
# strangely, this should be encrypted using MTProto1
# see https://core.telegram.org/method/auth.bindTempAuthKey#encrypting-the-binding-message
msg_id = self._get_new_msg_id()
seq_no = 0
bind = BindAuthKeyInner(
nonce=nonce,
temp_auth_key_id=tmp_auth_key.key_id,
perm_auth_key_id=auth_key.key_id,
temp_session_id=self.id,
expires_at=timestamp
)
bind = bytes(bind)
assert len(bind) == 40
payload = os.urandom(int(128/8)) + struct.pack('<qii', msg_id, seq_no, len(bind)) + bind
padding = os.urandom(len(payload) % 16)
msg_key = sha1(payload).digest()[4:20]
aes_key, aes_iv = self._calc_key_v1(auth_key.key, msg_key, True)
key_id = struct.pack('<q', auth_key.key_id)
crypt = AES.encrypt_ige(payload + padding, aes_key, aes_iv)
encrypted_message = (key_id + msg_key + crypt)
return encrypted_message, msg_id
def encrypt_message_data(self, data):
"""
Encrypts the given message data using the current authorization key

View File

@ -10,9 +10,9 @@ class RequestState:
"""
__slots__ = ('container_id', 'msg_id', 'request', 'data', 'future', 'after')
def __init__(self, request, after=None):
def __init__(self, request, after=None, msg_id=None):
self.container_id = None
self.msg_id = None
self.msg_id = msg_id
self.request = request
self.data = bytes(request)
self.future = asyncio.Future()

View File

@ -32,6 +32,7 @@ class MemorySession(Session):
self._server_address = None
self._port = None
self._auth_key = None
self._tmp_auth_key = None
self._takeout_id = None
self._files = {}
@ -59,10 +60,18 @@ class MemorySession(Session):
def auth_key(self):
return self._auth_key
@property
def tmp_auth_key(self):
return self._tmp_auth_key
@auth_key.setter
def auth_key(self, value):
self._auth_key = value
@auth_key.setter
def tmp_auth_key(self, value):
self._tmp_auth_key = value
@property
def takeout_id(self):
return self._takeout_id

View File

@ -18,7 +18,7 @@ except ImportError as e:
sqlite3_err = type(e)
EXTENSION = '.session'
CURRENT_VERSION = 7 # database version
CURRENT_VERSION = 8 # database version
class SQLiteSession(MemorySession):
@ -30,13 +30,14 @@ class SQLiteSession(MemorySession):
through an official Telegram client to revoke the authorization.
"""
def __init__(self, session_id=None):
def __init__(self, session_id=None, store_tmp_auth_key_on_disk:bool=False):
if sqlite3 is None:
raise sqlite3_err
super().__init__()
self.filename = ':memory:'
self.save_entities = True
self.store_tmp_auth_key_on_disk = store_tmp_auth_key_on_disk
if session_id:
self.filename = session_id
@ -61,9 +62,10 @@ class SQLiteSession(MemorySession):
c.execute('select * from sessions')
tuple_ = c.fetchone()
if tuple_:
self._dc_id, self._server_address, self._port, key, \
self._dc_id, self._server_address, self._port, key, tmp_key, \
self._takeout_id = tuple_
self._auth_key = AuthKey(data=key)
self._tmp_auth_key = AuthKey(data=tmp_key)
c.close()
else:
@ -77,7 +79,8 @@ class SQLiteSession(MemorySession):
server_address text,
port integer,
auth_key blob,
takeout_id integer
takeout_id integer,
tmp_auth_key blob
)"""
,
"""entities (
@ -153,6 +156,9 @@ class SQLiteSession(MemorySession):
if old == 6:
old += 1
c.execute("alter table entities add column date integer")
if old == 7:
old += 1
c.execute("alter table sessions add column tmp_auth_key blob")
c.close()
@ -168,17 +174,27 @@ class SQLiteSession(MemorySession):
self._update_session_table()
# Fetch the auth_key corresponding to this data center
row = self._execute('select auth_key from sessions')
row = self._execute('select auth_key, tmp_auth_key from sessions')
if row and row[0]:
self._auth_key = AuthKey(data=row[0])
else:
self._auth_key = None
if row and row[1]:
self._tmp_auth_key = AuthKey(data=row[1])
else:
self._tmp_auth_key = None
@MemorySession.auth_key.setter
def auth_key(self, value):
self._auth_key = value
self._update_session_table()
@MemorySession.tmp_auth_key.setter
def tmp_auth_key(self, value):
self._tmp_auth_key = value
self._update_session_table()
@MemorySession.takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
@ -192,12 +208,13 @@ class SQLiteSession(MemorySession):
# some more work before being able to save auth_key's for
# multiple DCs. Probably done differently.
c.execute('delete from sessions')
c.execute('insert or replace into sessions values (?,?,?,?,?)', (
c.execute('insert or replace into sessions values (?,?,?,?,?,?)', (
self._dc_id,
self._server_address,
self._port,
self._auth_key.key if self._auth_key else b'',
self._takeout_id
self._takeout_id,
self._tmp_auth_key.key if (self.store_tmp_auth_key_on_disk and self._tmp_auth_key) else b''
))
c.close()