diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 77768b78..62160857 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -261,7 +261,8 @@ class TelegramBaseClient(abc.ABC): base_logger: typing.Union[str, logging.Logger] = None, receive_updates: bool = True, catch_up: bool = False, - entity_cache_limit: int = 5000 + entity_cache_limit: int = 5000, + store_tmp_auth_key_on_disk: bool = True ): if not api_id or not api_hash: raise ValueError( @@ -287,7 +288,7 @@ class TelegramBaseClient(abc.ABC): # Determine what session object we have if isinstance(session, (str, pathlib.Path)): try: - session = SQLiteSession(str(session)) + session = SQLiteSession(str(session), store_tmp_auth_key_on_disk=store_tmp_auth_key_on_disk) except ImportError: import warnings warnings.warn( @@ -459,6 +460,7 @@ class TelegramBaseClient(abc.ABC): auto_reconnect=self._auto_reconnect, connect_timeout=self._timeout, auth_key_callback=self._auth_key_callback, + tmp_auth_key_callback=self._tmp_auth_key_callback, updates_queue=self._updates_queue, auto_reconnect_callback=self._handle_auto_reconnect ) @@ -558,6 +560,7 @@ class TelegramBaseClient(abc.ABC): return self.session.auth_key = self._sender.auth_key + self.session.tmp_auth_key = self._sender.tmp_auth_key self.session.save() try: @@ -770,6 +773,14 @@ class TelegramBaseClient(abc.ABC): self.session.auth_key = auth_key self.session.save() + def _tmp_auth_key_callback(self: 'TelegramClient', tmp_auth_key): + """ + Callback from the sender whenever it needed to generate a + new authorization key. This means we are not authorized. + """ + self.session.tmp_auth_key = tmp_auth_key + self.session.save() + # endregion # region Working with different connections/Data Centers diff --git a/telethon/crypto/authkey.py b/telethon/crypto/authkey.py index 37f65017..9946d7df 100644 --- a/telethon/crypto/authkey.py +++ b/telethon/crypto/authkey.py @@ -12,13 +12,16 @@ class AuthKey: Represents an authorization key, used to encrypt and decrypt messages sent to Telegram's data centers. """ - def __init__(self, data): + def __init__(self, data, expires_at: int = -1): """ Initializes a new authorization key. :param data: the data in bytes that represent this auth key. + :param expires_at: unix timestamp of key expiry """ self.key = data + self.expires_at = expires_at + self.tmp_key_bound = False @property def key(self): diff --git a/telethon/extensions/messagepacker.py b/telethon/extensions/messagepacker.py index c0f46f48..ac96c1b6 100644 --- a/telethon/extensions/messagepacker.py +++ b/telethon/extensions/messagepacker.py @@ -60,7 +60,8 @@ class MessagePacker: if size <= MessageContainer.MAXIMUM_SIZE: state.msg_id = self._state.write_data_as_message( buffer, state.data, isinstance(state.request, TLRequest), - after_id=state.after.msg_id if state.after else None + after_id=state.after.msg_id if state.after else None, + msg_id=state.msg_id ) batch.append(state) self._log.debug('Assigned msg_id = %d to %s (%x)', diff --git a/telethon/network/authenticator.py b/telethon/network/authenticator.py index ea476207..e32df6f9 100644 --- a/telethon/network/authenticator.py +++ b/telethon/network/authenticator.py @@ -7,7 +7,7 @@ import time from hashlib import sha1 from ..tl.types import ( - ResPQ, PQInnerData, ServerDHParamsFail, ServerDHParamsOk, + ResPQ, PQInnerData, PQInnerDataTemp, ServerDHParamsFail, ServerDHParamsOk, ServerDHInnerData, ClientDHInnerData, DhGenOk, DhGenRetry, DhGenFail ) from .. import helpers @@ -17,13 +17,15 @@ from ..extensions import BinaryReader from ..tl.functions import ( ReqPqMultiRequest, ReqDHParamsRequest, SetClientDHParamsRequest ) +from ..tl.functions.auth import BindTempAuthKeyRequest - -async def do_authentication(sender): +async def do_authentication(sender, tmp_auth_key=False, tmp_auth_key_expires_s = 24 * 3600): """ Executes the authentication process with the Telegram servers. :param sender: a connected `MTProtoPlainSender`. + :param tmp_auth_key: whether the flow for a tmp_auth_key should be executed + :param tmp_auth_key_expires_s: duration in s until tmp_auth_key expires :return: returns a (authorization key, time offset) tuple. """ # Step 1 sending: PQ Request, endianness doesn't matter since it's random @@ -41,12 +43,23 @@ async def do_authentication(sender): p, q = rsa.get_byte_array(p), rsa.get_byte_array(q) new_nonce = int.from_bytes(os.urandom(32), 'little', signed=True) - pq_inner_data = bytes(PQInnerData( - pq=rsa.get_byte_array(pq), p=p, q=q, - nonce=res_pq.nonce, - server_nonce=res_pq.server_nonce, - new_nonce=new_nonce - )) + if tmp_auth_key: + expires_at = int(time.time()) + tmp_auth_key_expires_s + pq_inner_data = bytes(PQInnerDataTemp( + pq=rsa.get_byte_array(pq), p=p, q=q, + nonce=res_pq.nonce, + server_nonce=res_pq.server_nonce, + new_nonce=new_nonce, + expires_in=tmp_auth_key_expires_s + )) + + else: + pq_inner_data = bytes(PQInnerData( + pq=rsa.get_byte_array(pq), p=p, q=q, + nonce=res_pq.nonce, + server_nonce=res_pq.server_nonce, + new_nonce=new_nonce + )) # sha_digest + data + random_bytes cipher_text, target_fingerprint = None, None @@ -197,8 +210,33 @@ async def do_authentication(sender): if not isinstance(dh_gen, DhGenOk): raise AssertionError('Step 3.2 answer was %s' % dh_gen) - return auth_key, time_offset + if tmp_auth_key: + # auth_key is only a tmp_auth_key here! + return auth_key, expires_at + else: + return auth_key, time_offset +async def bind_tmp_auth_key(sender, auth_key, tmp_auth_key): + """ + Binds a tmpAuthKey to an authkey. + + :param sender: a connected `MTProtoender`. + :param auth_key: auth key to bind to + :parm tmp_auth_key: unbound tmp_auth_key + :return: None + """ + nonce = int.from_bytes(os.urandom(8), 'little', signed=True) + timestamp = tmp_auth_key.expires_at + encrypted_bind, msg_id = sender._state.get_encrypted_bind(nonce, auth_key, tmp_auth_key, timestamp) + + bind_request = BindTempAuthKeyRequest( + perm_auth_key_id=auth_key.key_id, + nonce=nonce, + expires_at=timestamp, + encrypted_message=encrypted_bind + ) + + await sender.send(bind_request, msg_id=msg_id) def get_int(byte_array, signed=True): """ diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index d53f9ce8..f1ebcac6 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -47,7 +47,7 @@ class MTProtoSender: """ def __init__(self, auth_key, *, loggers, retries=5, delay=1, auto_reconnect=True, connect_timeout=None, - auth_key_callback=None, + auth_key_callback=None, tmp_auth_key=None, tmp_auth_key_callback=None, updates_queue=None, auto_reconnect_callback=None): self._connection = None self._loggers = loggers @@ -57,6 +57,7 @@ class MTProtoSender: self._auto_reconnect = auto_reconnect self._connect_timeout = connect_timeout self._auth_key_callback = auth_key_callback + self._tmp_auth_key_callback = tmp_auth_key_callback self._updates_queue = updates_queue self._auto_reconnect_callback = auto_reconnect_callback self._connect_lock = asyncio.Lock() @@ -79,7 +80,8 @@ class MTProtoSender: # Preserving the references of the AuthKey and state is important self.auth_key = auth_key or AuthKey(None) - self._state = MTProtoState(self.auth_key, loggers=self._loggers) + self.tmp_auth_key = tmp_auth_key or AuthKey(None) + self._state = MTProtoState(self.tmp_auth_key, loggers=self._loggers) # Outgoing messages are put in a queue and sent in a batch. # Note that here we're also storing their ``_RequestState``. @@ -152,7 +154,7 @@ class MTProtoSender: """ await self._disconnect() - def send(self, request, ordered=False): + def send(self, request, ordered=False, msg_id=None): """ This method enqueues the given request to be sent. Its send state will be saved until a response arrives, and a ``Future`` @@ -180,7 +182,7 @@ class MTProtoSender: if not utils.is_list_like(request): try: - state = RequestState(request) + state = RequestState(request, msg_id=msg_id) except struct.error as e: # "struct.error: required argument is not an integer" is not # very helpful; log the request to find out what wasn't int. @@ -254,6 +256,25 @@ class MTProtoSender: await asyncio.sleep(self._delay) continue # next iteration we will try to reconnect + if not self.tmp_auth_key: + # establish tmp_auth_key here, but make sure to bind to the auth_key later + try: + if not await self._try_gen_tmp_auth_key(attempt): + continue # keep retrying until we have the tmp auth key + except (IOError, asyncio.TimeoutError) as e: + # Sometimes, specially during user-DC migrations, + # Telegram may close the connection during auth_key + # generation. If that's the case, we will need to + # connect again. + self._log.warning('Connection error %d during tmp_auth_key gen: %s: %s', + attempt, type(e).__name__, e) + + # Whatever the IOError was, make sure to disconnect so we can + # reconnect cleanly after. + await self._connection.disconnect() + connected = False + await asyncio.sleep(self._delay) + continue # next iteration we will try to reconnect break # all steps done, break retry loop else: if not connected: @@ -264,15 +285,21 @@ class MTProtoSender: raise e loop = helpers.get_running_loop() + # dirty hack, but otherwise we cannot send the binding of the tmp_auth_key + self._user_connected = True + # update key, this was unavailable at init of self._state + self._state.auth_key = self.tmp_auth_key self._log.debug('Starting send loop') self._send_loop_handle = loop.create_task(self._send_loop()) self._log.debug('Starting receive loop') self._recv_loop_handle = loop.create_task(self._recv_loop()) - # _disconnected only completes after manual disconnection - # or errors after which the sender cannot continue such - # as failing to reconnect or any unexpected error. + # both self.auth_key and self.tmp_auth_key are required for the binding + # and it can only take place after the send/recv loops as the + # the binding message needs to be sent encrypted + await authenticator.bind_tmp_auth_key(self, self.auth_key, self.tmp_auth_key) + if self._disconnected.done(): self._disconnected = loop.create_future() @@ -290,6 +317,27 @@ class MTProtoSender: await asyncio.sleep(self._delay) return False + async def _try_gen_tmp_auth_key(self, attempt): + plain = MTProtoPlainSender(self._connection, loggers=self._loggers) + try: + self._log.debug('New tmp_auth_key attempt %d...', attempt) + self.tmp_auth_key.key, self.tmp_auth_key.expires_at = \ + await authenticator.do_authentication(plain, tmp_auth_key=True) + + # This is *EXTREMELY* important since we don't control + # external references to the temporary authorization key, we must + # notify whenever we change it. This is crucial when we + # switch to different data centers. + if self._tmp_auth_key_callback: + self._tmp_auth_key_callback(self.tmp_auth_key) + + self._log.info('tmp_auth_key generation success!') + return True + except (SecurityError, AssertionError) as e: + self._log.warning('Attempt %d at new tmp_auth_key failed: %s', attempt, e) + await asyncio.sleep(self._delay) + return False + async def _try_gen_auth_key(self, attempt): plain = MTProtoPlainSender(self._connection, loggers=self._loggers) try: @@ -366,7 +414,10 @@ class MTProtoSender: self._reconnecting = False # Start with a clean state (and thus session ID) to avoid old msgs - self._state.reset() + self._state.reset(keep_key=False) + self.tmp_auth_key = AuthKey(None) + if self._tmp_auth_key_callback: + self._tmp_auth_key_callback(self.tmp_auth_key) retries = self._retries if self._auto_reconnect else 0 diff --git a/telethon/network/mtprotostate.py b/telethon/network/mtprotostate.py index 0ceaa513..40dfb1de 100644 --- a/telethon/network/mtprotostate.py +++ b/telethon/network/mtprotostate.py @@ -1,7 +1,7 @@ import os import struct import time -from hashlib import sha256 +from hashlib import sha1, sha256 from collections import deque from ..crypto import AES @@ -11,8 +11,7 @@ from ..tl.core import TLMessage from ..tl.tlobject import TLRequest from ..tl.functions import InvokeAfterMsgRequest from ..tl.core.gzippacked import GzipPacked -from ..tl.types import BadServerSalt, BadMsgNotification - +from ..tl.types import BadServerSalt, BadMsgNotification, BindAuthKeyInner # N is not specified in https://core.telegram.org/mtproto/security_guidelines#checking-msg-id, but 500 is reasonable MAX_RECENT_MSG_IDS = 500 @@ -106,14 +105,38 @@ class MTProtoState: return aes_key, aes_iv + @staticmethod + def _calc_key_v1(auth_key, msg_key, client): + """ + Calculate the key based on Telegram guidelines for MTProto 1, + specifying whether it's the client or not. See + https://core.telegram.org/mtproto/description_v1#defining-aes-key-and-initialization-vector + """ + x = 0 if client else 8 + + sha1a = sha1(msg_key + auth_key[x:x+32]).digest() + sha1b = sha1(auth_key[x+32:x+48] + msg_key + auth_key[x+48:x+64]).digest() + sha1c = sha1(auth_key[x+64:x+96] + msg_key).digest() + sha1d = sha1(msg_key + auth_key[x+96:x+128]).digest() + + aes_key = sha1a[0:8] + sha1b[8:20] + sha1c[4:16] + aes_iv = sha1a[8:20] + sha1b[0:8] + sha1c[16:20] + sha1d[0:8] + + return aes_key, aes_iv + def write_data_as_message(self, buffer, data, content_related, - *, after_id=None): + *, after_id=None, msg_id=None): """ Writes a message containing the given data into buffer. Returns the message id. """ - msg_id = self._get_new_msg_id() + if msg_id is None: + # this should be the default - the binding of a tmpAuthKey is the only + # exception, as the msg_id of the encrypted BindAuthKeyInner needs to be equal to the + # msg_id of the outer message + # see: https://core.telegram.org/method/auth.bindTempAuthKey#encrypting-the-binding-message + msg_id = self._get_new_msg_id() seq_no = self._get_seq_no(content_related) if after_id is None: body = GzipPacked.gzip_if_smaller(content_related, data) @@ -127,6 +150,32 @@ class MTProtoState: buffer.write(body) return msg_id + def get_encrypted_bind(self, nonce, auth_key, tmp_auth_key, timestamp): + # strangely, this should be encrypted using MTProto1 + # see https://core.telegram.org/method/auth.bindTempAuthKey#encrypting-the-binding-message + msg_id = self._get_new_msg_id() + seq_no = 0 + bind = BindAuthKeyInner( + nonce=nonce, + temp_auth_key_id=tmp_auth_key.key_id, + perm_auth_key_id=auth_key.key_id, + temp_session_id=self.id, + expires_at=timestamp + ) + bind = bytes(bind) + assert len(bind) == 40 + + payload = os.urandom(int(128/8)) + struct.pack('