From e2fe3eb503e9ad9ab0d8eaf4c6c82281cda9bfc3 Mon Sep 17 00:00:00 2001
From: Lonami Exo
Date: Fri, 19 Oct 2018 13:24:52 +0200
Subject: [PATCH] Use new broken MessagePacker

---
 telethon/crypto/authkey.py           |   6 +
 telethon/extensions/messagepacker.py | 116 ++++++++++++++++++++
 telethon/helpers.py                  |  42 +------
 telethon/network/mtprotolayer.py     | 158 ---------------------------
 telethon/network/mtprotosender.py    |  58 +++++-----
 telethon/network/mtprotostate.py     |  10 +-
 6 files changed, 162 insertions(+), 228 deletions(-)
 create mode 100644 telethon/extensions/messagepacker.py
 delete mode 100644 telethon/network/mtprotolayer.py

diff --git a/telethon/crypto/authkey.py b/telethon/crypto/authkey.py
index 1af8ab91..8475ec17 100644
--- a/telethon/crypto/authkey.py
+++ b/telethon/crypto/authkey.py
@@ -30,12 +30,18 @@ class AuthKey:
             self._key = self.aux_hash = self.key_id = None
             return
 
+        if isinstance(value, type(self)):
+            self._key, self.aux_hash, self.key_id = \
+                value._key, value.aux_hash, value.key_id
+            return
+
         self._key = value
         with BinaryReader(sha1(self._key).digest()) as reader:
            self.aux_hash = reader.read_long(signed=False)
            reader.read(4)
            self.key_id = reader.read_long(signed=False)
 
+    # TODO This doesn't really fit here, it's only used in authentication
     def calc_new_nonce_hash(self, new_nonce, number):
         """
         Calculates the new nonce hash based on the current attributes.
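Note on the hunk above: the new `isinstance(value, type(self))` branch lets one `AuthKey` be assigned to another by copying its fields instead of rebinding the object, so a single instance can be shared (for example between the sender and its `MTProtoState`) and every holder observes the new key. A minimal standalone sketch of the idea, using a simplified stand-in class rather than the real `telethon.crypto.AuthKey` (names and the struct-based hashing below are illustrative only):

import os
import struct
from hashlib import sha1


class SharedKey:
    """Simplified stand-in for telethon.crypto.AuthKey (illustrative only)."""
    def __init__(self, data):
        self.key = data

    @property
    def key(self):
        return self._key

    @key.setter
    def key(self, value):
        if not value:
            self._key = self.aux_hash = self.key_id = None
            return

        if isinstance(value, type(self)):
            # Copy the fields rather than rebinding the object, so every
            # holder of this instance sees the new key.
            self._key, self.aux_hash, self.key_id = \
                value._key, value.aux_hash, value.key_id
            return

        self._key = value
        digest = sha1(self._key).digest()
        # aux_hash comes from the first 8 bytes of the SHA1 and key_id from
        # the last 8, mirroring the BinaryReader logic in the hunk above.
        self.aux_hash = struct.unpack('<Q', digest[:8])[0]
        self.key_id = struct.unpack('<Q', digest[12:])[0]


shared = SharedKey(None)        # one instance handed to sender and state alike
state_ref = shared              # e.g. what MTProtoState would keep around

shared.key = os.urandom(256)    # authentication produced the raw key bytes
assert state_ref.key_id == shared.key_id

fresh = SharedKey(os.urandom(256))
shared.key = fresh              # assigning another instance copies its fields
assert shared.key_id == fresh.key_id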
diff --git a/telethon/extensions/messagepacker.py b/telethon/extensions/messagepacker.py
new file mode 100644
index 00000000..c4b410fe
--- /dev/null
+++ b/telethon/extensions/messagepacker.py
@@ -0,0 +1,116 @@
+import asyncio
+import collections
+import io
+import logging
+import struct
+
+from ..tl import TLRequest
+from ..tl.core.messagecontainer import MessageContainer
+from ..tl.core.tlmessage import TLMessage
+
+__log__ = logging.getLogger(__name__)
+
+
+class MessagePacker:
+    """
+    This class packs `RequestState` as outgoing `TLMessages`.
+
+    The purpose of this class is to support putting N `RequestState` into a
+    queue, and then awaiting the "packed" `TLMessage` on the other end. The
+    simplest case would be ``State -> TLMessage`` (a 1-to-1 relationship),
+    but for efficiency purposes it's ``States -> Container`` (N-to-1).
+
+    This addresses several needs: outgoing messages will be smaller, so the
+    encryption and network overhead is also smaller. It's also a central
+    point where outgoing requests are put and from which ready messages
+    are retrieved.
+    """
+    def __init__(self, state, loop):
+        self._state = state
+        self._loop = loop
+        self._deque = collections.deque()
+        self._ready = asyncio.Event(loop=loop)
+
+    def append(self, state):
+        self._deque.append(state)
+        self._ready.set()
+
+    def extend(self, states):
+        self._deque.extend(states)
+        self._ready.set()
+
+    async def get(self, cancellation):
+        """
+        Returns (batch, data) if one or more items could be retrieved.
+
+        If the cancellation occurs or only invalid items were in the
+        queue, (None, None) will be returned instead.
+        """
+        if not self._deque:
+            self._ready.clear()
+
+        ready = self._loop.create_task(self._ready.wait())
+        try:
+            done, pending = await asyncio.wait(
+                [ready, cancellation],
+                return_when=asyncio.FIRST_COMPLETED,
+                loop=self._loop
+            )
+        except asyncio.CancelledError:
+            done = [cancellation]
+
+        if cancellation in done:
+            ready.cancel()
+            return None, None
+
+        buffer = io.BytesIO()
+        batch = []
+        size = 0
+
+        # Fill a new batch to return while the size is small enough
+        while self._deque:
+            state = self._deque.popleft()
+            size += len(state.data) + TLMessage.SIZE_OVERHEAD
+
+            if size <= MessageContainer.MAXIMUM_SIZE:
+                # TODO Implement back using after_id
+                state.msg_id = self._state.write_data_as_message(
+                    buffer, state.data, isinstance(state.request, TLRequest)
+                )
+                batch.append(state)
+                __log__.debug('Assigned msg_id = %d to %s (%x)',
+                              state.msg_id, state.request.__class__.__name__,
+                              id(state.request))
+                continue
+
+            # Put the item back since it can't be sent in this batch
+            self._deque.appendleft(state)
+            if batch:
+                break
+
+            # If a single message exceeds the maximum size, then the
+            # message payload cannot be sent. Telegram would forcibly
+            # close the connection; message would never be confirmed.
+            state.future.set_exception(
+                ValueError('Request payload is too big'))
+
+            size = 0
+            continue
+
+        if not batch:
+            return None, None
+
+        if len(batch) > 1:
+            # Inlined code to pack several messages into a container
+            data = struct.pack(
+                '
-            if size > MessageContainer.MAXIMUM_SIZE:
-                size -= MessageContainer.MAXIMUM_SIZE
-                if len(batch) > 1:
-                    # Inlined code to pack several messages into a container
-                    data = struct.pack(
-                        '
-            if size > MessageContainer.MAXIMUM_SIZE:
-                state.future.set_exception(
-                    ValueError('Request payload is too big'))
-                return
-
-            # This is the only requirement to make this work.
-            state.msg_id = self._state.write_data_as_message(
-                buffer, state.data, isinstance(state.request, TLRequest),
-                after_id=after_id
-            )
-            __log__.debug('Assigned msg_id = %d to %s (%x)',
-                          state.msg_id, state.request.__class__.__name__,
-                          id(state.request))
-
-        # TODO Yield in the inner loop -> Telegram "Invalid container". Why?
-        for state in state_list:
-            if not isinstance(state, list):
-                yield write_state(state)
-            else:
-                after_id = None
-                for s in state:
-                    yield write_state(s, after_id)
-                    after_id = s.msg_id
-
-        yield write_state(None)
-
-    def __str__(self):
-        return str(self._connection)
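The batching rule in `MessagePacker.get()` above is worth spelling out: states are popped until the accumulated size (payload plus per-message overhead) would cross `MessageContainer.MAXIMUM_SIZE`; the first state that does not fit is pushed back for the next batch, and a state too large even on its own has its future failed immediately. A self-contained sketch of just that size logic, with stand-in constants and plain byte strings instead of the real Telethon values and `RequestState` objects:

import collections

MAXIMUM_SIZE = 1044456 - 8   # stand-in for MessageContainer.MAXIMUM_SIZE
SIZE_OVERHEAD = 12           # stand-in for TLMessage.SIZE_OVERHEAD


def pack_batch(queue):
    """Pop as many payloads as fit in one container, mirroring the loop above."""
    batch, size = [], 0
    while queue:
        data = queue.popleft()
        size += len(data) + SIZE_OVERHEAD

        if size <= MAXIMUM_SIZE:
            batch.append(data)
            continue

        # Doesn't fit together with the rest: keep it for the next batch
        queue.appendleft(data)
        if batch:
            break

        # A single payload that is too big on its own can never be sent;
        # the real code fails that request's future, here we just drop it.
        queue.popleft()
        size = 0

    return batch


queue = collections.deque([b'x' * 512, b'y' * 2048, b'z' * (2 * 1024 * 1024)])
print(len(pack_batch(queue)))   # 2: the oversized payload stays queued for now

Putting the overflow item back on the left of the deque, instead of dropping it, is what keeps request order stable across consecutive batches.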
diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py
index a12b38f0..3d043715 100644
--- a/telethon/network/mtprotosender.py
+++ b/telethon/network/mtprotosender.py
@@ -3,9 +3,10 @@ import collections
 import logging
 
 from . import authenticator
-from .mtprotolayer import MTProtoLayer
+from ..extensions.messagepacker import MessagePacker
 from .mtprotoplainsender import MTProtoPlainSender
 from .requeststate import RequestState
+from .mtprotostate import MTProtoState
 from ..tl.tlobject import TLRequest
 from .. import utils
 from ..errors import (
@@ -13,7 +14,6 @@ from ..errors import (
     InvalidChecksumError, rpc_message_to_error
 )
 from ..extensions import BinaryReader
-from ..helpers import _ReadyQueue
 from ..tl.core import RpcResult, MessageContainer, GzipPacked
 from ..tl.functions.auth import LogOutRequest
 from ..tl.types import (
@@ -21,7 +21,7 @@ from ..tl.types import (
     MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq,
     MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload
 )
-from ..utils import AsyncClassWrapper
+from ..crypto import AuthKey
 
 __log__ = logging.getLogger(__name__)
 
@@ -43,9 +43,9 @@ class MTProtoSender:
     """
     def __init__(self, loop, *,
                  retries=5, auto_reconnect=True, connect_timeout=None,
-                 update_callback=None,
+                 update_callback=None, auth_key=None,
                  auth_key_callback=None, auto_reconnect_callback=None):
-        self._connection = None  # MTProtoLayer, a.k.a. encrypted connection
+        self._connection = None
         self._loop = loop
         self._retries = retries
         self._auto_reconnect = auto_reconnect
@@ -68,10 +68,13 @@ class MTProtoSender:
         self._send_loop_handle = None
         self._recv_loop_handle = None
 
+        # Preserving the references of the AuthKey and state is important
+        self._auth_key = auth_key or AuthKey(None)
+        self._state = MTProtoState(self._auth_key)
+
         # Outgoing messages are put in a queue and sent in a batch.
         # Note that here we're also storing their ``_RequestState``.
-        # Note that it may also store lists (implying order must be kept).
-        self._send_queue = _ReadyQueue(self._loop)
+        self._send_queue = MessagePacker(self._state, self._loop)
 
         # Sent states are remembered until a response is received.
         self._pending_state = {}
@@ -112,7 +115,7 @@ class MTProtoSender:
             __log__.info('User is already connected!')
             return
 
-        self._connection = MTProtoLayer(connection, auth_key)
+        self._connection = connection
         self._user_connected = True
         await self._connect()
 
@@ -204,17 +207,16 @@ class MTProtoSender:
                                   .format(self._retries))
 
         __log__.debug('Connection success!')
-        state = self._connection._state
-        if state.auth_key is None:
-            plain = MTProtoPlainSender(self._connection._connection)
+        if not self._auth_key:
+            plain = MTProtoPlainSender(self._connection)
             for retry in range(1, self._retries + 1):
                 try:
                     __log__.debug('New auth_key attempt {}...'.format(retry))
-                    state.auth_key, state.time_offset =\
+                    self._auth_key.key, self._state.time_offset =\
                         await authenticator.do_authentication(plain)
 
                     if self._auth_key_callback:
-                        await self._auth_key_callback(state.auth_key)
+                        await self._auth_key_callback(self._auth_key)
 
                     break
                 except (SecurityError, AssertionError) as e:
@@ -292,7 +294,7 @@ class MTProtoSender:
         self._reconnecting = False
 
         # Start with a clean state (and thus session ID) to avoid old msgs
-        self._connection.reset_state()
+        self._state.reset()
 
         retries = self._retries if self._auto_reconnect else 0
         for retry in range(1, retries + 1):
@@ -333,19 +335,19 @@ class MTProtoSender:
                 self._last_acks.append(ack)
                 self._pending_ack.clear()
 
-            state_list = await self._send_queue.get(
-                self._connection._connection.disconnected)
+            batch, data = await self._send_queue.get(
+                self._connection.disconnected)
 
-            if state_list is None:
-                break
+            if not data:
+                continue
 
             try:
-                await self._connection.send(state_list)
+                await self._connection.send(data)
             except Exception:
                 __log__.exception('Unhandled error while sending data')
                 continue
 
-            for state in state_list:
+            for state in batch:
                 if not isinstance(state, list):
                     if isinstance(state.request, TLRequest):
                         self._pending_state[state.msg_id] = state
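The send loop above now awaits `self._send_queue.get(self._connection.disconnected)`, i.e. it races "the queue has items" against "the connection dropped", which is what lets it wake up and bail out cleanly on disconnection instead of blocking on an empty queue forever. A minimal sketch of that pattern with plain asyncio; it uses the modern API, whereas the patch still passes the `loop=` arguments that the Python versions of 2018 required:

import asyncio


async def next_batch(ready: asyncio.Event, disconnected: asyncio.Future):
    """Wait for queued work or for the connection to drop, whichever is first."""
    ready_task = asyncio.ensure_future(ready.wait())
    done, _ = await asyncio.wait(
        [ready_task, disconnected],
        return_when=asyncio.FIRST_COMPLETED,
    )
    if disconnected in done:
        ready_task.cancel()      # don't leak the Event waiter
        return None              # the caller restarts its loop / reconnects
    return 'batch'               # the caller would now drain the queue


async def main():
    ready = asyncio.Event()
    disconnected = asyncio.get_running_loop().create_future()

    asyncio.get_running_loop().call_later(0.1, ready.set)   # work arrives
    print(await next_batch(ready, disconnected))            # -> 'batch'


asyncio.run(main())

In the patch, the `ready` waiter is the packer's internal `asyncio.Event` and `disconnected` is a future owned by the connection; whichever finishes first decides whether `get()` returns `(batch, data)` or `(None, None)`.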
@@ -364,7 +366,9 @@ class MTProtoSender:
         while self._user_connected and not self._reconnecting:
             __log__.debug('Receiving items from the network...')
             try:
-                message = await self._connection.recv()
+                # TODO Split except
+                body = await self._connection.recv()
+                message = self._state.decrypt_message_data(body)
             except TypeNotFoundError as e:
                 __log__.info('Type %08x not found, remaining data %r',
                              e.invalid_constructor_id, e.remaining)
@@ -388,7 +392,7 @@ class MTProtoSender:
                 else:
                     __log__.warning('Invalid buffer %s', e)
-                    self._connection._state.auth_key = None
+                    self._auth_key.key = None
                     self._start_reconnect()
                 return
             except asyncio.IncompleteReadError:
@@ -533,7 +537,7 @@ class MTProtoSender:
         """
         bad_salt = message.obj
         __log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id)
-        self._connection._state.salt = bad_salt.new_server_salt
+        self._state.salt = bad_salt.new_server_salt
 
         states = self._pop_states(bad_salt.bad_msg_id)
         self._send_queue.extend(states)
@@ -554,16 +558,16 @@ class MTProtoSender:
         if bad_msg.error_code in (16, 17):
             # Sent msg_id too low or too high (respectively).
             # Use the current msg_id to determine the right time offset.
-            to = self._connection._state.update_time_offset(
+            to = self._state.update_time_offset(
                 correct_msg_id=message.msg_id)
             __log__.info('System clock is wrong, set time offset to %ds', to)
         elif bad_msg.error_code == 32:
             # msg_seqno too low, so just pump it up by some "large" amount
             # TODO A better fix would be to start with a new fresh session ID
-            self._connection._state._sequence += 64
+            self._state._sequence += 64
         elif bad_msg.error_code == 33:
             # msg_seqno too high never seems to happen but just in case
-            self._connection._state._sequence -= 16
+            self._state._sequence -= 16
         else:
             for state in states:
                 state.future.set_exception(BadMessageError(bad_msg.error_code))
@@ -606,7 +610,7 @@ class MTProtoSender:
         """
         # TODO https://goo.gl/LMyN7A
         __log__.debug('Handling new session created')
-        self._connection._state.salt = message.obj.server_salt
+        self._state.salt = message.obj.server_salt
 
     async def _handle_ack(self, message):
         """
diff --git a/telethon/network/mtprotostate.py b/telethon/network/mtprotostate.py
index 61642620..99c78af4 100644
--- a/telethon/network/mtprotostate.py
+++ b/telethon/network/mtprotostate.py
@@ -38,11 +38,17 @@ class MTProtoState:
     authentication process, at which point the `MTProtoPlainSender` is better.
     """
     def __init__(self, auth_key):
-        # Session IDs can be random on every connection
-        self.id = struct.unpack('q', os.urandom(8))[0]
         self.auth_key = auth_key
         self.time_offset = 0
         self.salt = 0
+        self.reset()
+
+    def reset(self):
+        """
+        Resets the state.
+        """
+        # Session IDs can be random on every connection
+        self.id = struct.unpack('q', os.urandom(8))[0]
         self._sequence = 0
         self._last_msg_id = 0
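Finally, moving the session-reset logic into `MTProtoState.reset()` is what allows `_reconnect()` to call `self._state.reset()` and start from a clean session (new random ID, sequence and last message ID back to zero) while keeping the same `MTProtoState` and `AuthKey` objects alive. A tiny stand-in illustrating the behaviour (the class below is illustrative, not the real `MTProtoState`):

import os
import struct


class SessionState:
    """Tiny stand-in for the reset() behaviour added to MTProtoState."""
    def __init__(self):
        self.time_offset = 0
        self.salt = 0
        self.reset()

    def reset(self):
        # A fresh, random session ID per (re)connection means the server
        # treats the new connection as a new session, so stale messages
        # from the previous one cannot be confused with the current ones.
        self.id = struct.unpack('q', os.urandom(8))[0]
        self._sequence = 0       # msg_seqno restarts with the session
        self._last_msg_id = 0    # msg_id monotonicity is per session


state = SessionState()
old_id = state.id
state.reset()                    # what _reconnect() now triggers via self._state.reset()
assert state.id != old_id        # overwhelmingly likely with 64 random bits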