diff --git a/telethon/helpers.py b/telethon/helpers.py index 1452aad2..e0cced62 100644 --- a/telethon/helpers.py +++ b/telethon/helpers.py @@ -1,4 +1,5 @@ """Various helpers not related to the Telegram API itself""" +import asyncio import collections import os import struct @@ -87,4 +88,42 @@ class TotalList(list): return '[{}, total={}]'.format( ', '.join(repr(x) for x in self), self.total) + +class _ReadyQueue: + """ + A queue list that supports an arbitrary cancellation token for `get`. + """ + def __init__(self, loop): + self._list = [] + self._loop = loop + self._ready = asyncio.Event(loop=loop) + + def append(self, item): + self._list.append(item) + self._ready.set() + + def extend(self, items): + self._list.extend(items) + self._ready.set() + + async def get(self, cancellation): + """ + Returns a list of all the items added to the queue until now and + clears the list from the queue itself. Returns ``None`` if cancelled. + """ + ready = asyncio.ensure_future(self._ready.wait(), loop=self._loop) + done, pending = await asyncio.wait( + [ready, cancellation], + return_when=asyncio.FIRST_COMPLETED, + loop=self._loop + ) + if cancellation in done: + ready.cancel() + return None + + result = self._list + self._list = [] + self._ready.clear() + return result + # endregion diff --git a/telethon/network/connection/connection.py b/telethon/network/connection/connection.py index be609a57..abe0fa60 100644 --- a/telethon/network/connection/connection.py +++ b/telethon/network/connection/connection.py @@ -21,6 +21,7 @@ class Connection(abc.ABC): self._writer = None self._disconnected = asyncio.Event(loop=loop) self._disconnected.set() + self._disconnected_future = None self._send_task = None self._recv_task = None self._send_queue = asyncio.Queue(1) @@ -34,6 +35,7 @@ class Connection(abc.ABC): self._ip, self._port, loop=self._loop) self._disconnected.clear() + self._disconnected_future = None self._send_task = self._loop.create_task(self._send_loop()) self._recv_task = self._loop.create_task(self._recv_loop()) @@ -46,6 +48,13 @@ class Connection(abc.ABC): self._recv_task.cancel() self._writer.close() + @property + def disconnected(self): + if not self._disconnected_future: + self._disconnected_future = asyncio.ensure_future( + self._disconnected.wait(), loop=self._loop) + return self._disconnected_future + def clone(self): """ Creates a clone of the connection. diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index 604cd0fb..582945ec 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -9,6 +9,7 @@ from ..errors import ( BadMessageError, TypeNotFoundError, rpc_message_to_error ) from ..extensions import BinaryReader +from ..helpers import _ReadyQueue from ..tl.core import RpcResult, MessageContainer, GzipPacked from ..tl.functions.auth import LogOutRequest from ..tl.types import ( @@ -64,10 +65,7 @@ class MTProtoSender: # Outgoing messages are put in a queue and sent in a batch. # Note that here we're also storing their ``_RequestState``. # Note that it may also store lists (implying order must be kept). - # - # TODO Abstract this queue away? - self._send_queue = [] - self._send_ready = asyncio.Event(loop=self._loop) + self._send_queue = _ReadyQueue(self._loop) # Sent states are remembered until a response is received. self._pending_state = {} @@ -192,7 +190,6 @@ class MTProtoSender: if not utils.is_list_like(request): state = RequestState(request, self._loop) self._send_queue.append(state) - self._send_ready.set() return state.future else: states = [] @@ -206,7 +203,6 @@ class MTProtoSender: else: self._send_queue.extend(states) - self._send_ready.set() return futures @property @@ -333,29 +329,15 @@ class MTProtoSender: if self._pending_ack: ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop) self._send_queue.append(ack) - self._send_ready.set() self._last_acks.append(ack) self._pending_ack.clear() - queue = asyncio.ensure_future( - self._send_ready.wait(), loop=self._loop) + state_list = await self._send_queue.get( + self._connection._connection.disconnected) - disconnected = asyncio.ensure_future( - self._connection._connection._disconnected.wait()) - - # Basically using the disconnected as a cancellation token - done, pending = await asyncio.wait( - [queue, disconnected], - return_when=asyncio.FIRST_COMPLETED, - loop=self._loop - ) - if disconnected in done: + if state_list is None: break - state_list = self._send_queue - self._send_queue = [] - self._send_ready.clear() - # TODO Debug logs to notify which messages are being sent # TODO Try sending them while no future was cancelled? # TODO Handle timeout, cancelled, arbitrary errors @@ -425,7 +407,6 @@ class MTProtoSender: error = rpc_message_to_error(rpc_result.error) self._send_queue.append( RequestState(MsgsAck([state.msg_id]), loop=self._loop)) - self._send_ready.set() if not state.future.cancelled(): state.future.set_exception(error) @@ -494,12 +475,10 @@ class MTProtoSender: try: self._send_queue.append( self._pending_state.pop(bad_salt.bad_msg_id)) - self._send_ready.set() except KeyError: for ack in self._pending_ack: if ack.msg_id == bad_salt.bad_msg_id: self._send_queue.append(ack) - self._send_ready.set() return __log__.info('Message %d not resent due to bad salt', @@ -539,7 +518,6 @@ class MTProtoSender: # Messages are to be re-sent once we've corrected the issue if state: self._send_queue.append(state) - self._send_ready.set() else: # TODO Generic method that may return from the acks too # May be MsgsAck, those are not saved in pending messages @@ -627,7 +605,6 @@ class MTProtoSender: self._send_queue.append(RequestState(MsgsStateInfo( req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)), loop=self._loop)) - self._send_ready.set() async def _handle_msg_all(self, message): """