diff --git a/telethon/network/connection/connection.py b/telethon/network/connection/connection.py index de72a154..fd6d3aa5 100644 --- a/telethon/network/connection/connection.py +++ b/telethon/network/connection/connection.py @@ -11,11 +11,13 @@ class Connection(abc.ABC): """ The `Connection` class is a wrapper around ``asyncio.open_connection``. - Subclasses are meant to communicate with this class through a queue. + Subclasses will implement different transport modes as atomic operations, + which this class eases doing since the exposed interface simply puts and + gets complete data payloads to and from queues. - This class provides a reliable interface that will stay connected - under any conditions for as long as the user doesn't disconnect or - the input parameters to auto-reconnect dictate otherwise. + The only error that will raise from send and receive methods is + ``ConnectionError``, which will raise when attempting to send if + the client is disconnected (includes remote disconnections). """ def __init__(self, ip, port, *, loop, proxy=None): self._ip = ip @@ -81,12 +83,20 @@ class Connection(abc.ABC): def disconnect(self): """ - Disconnects from the server. + Disconnects from the server, and clears + pending outgoing and incoming messages. """ self._disconnected.set() + + while not self._send_queue.empty(): + self._send_queue.get_nowait() + if self._send_task: self._send_task.cancel() + while not self._recv_queue.empty(): + self._recv_queue.get_nowait() + if self._recv_task: self._recv_task.cancel() @@ -112,6 +122,9 @@ class Connection(abc.ABC): This method returns a coroutine. """ + if self._disconnected.is_set(): + raise ConnectionError('Not connected') + return self._send_queue.put(data) async def recv(self): @@ -120,11 +133,14 @@ class Connection(abc.ABC): This method returns a coroutine. """ - ok, result = await self._recv_queue.get() - if ok: + if self._disconnected.is_set(): + raise ConnectionError('Not connected') + + result = await self._recv_queue.get() + if result: return result else: - raise result from None + raise ConnectionError('The server closed the connection') async def _send_loop(self): """ @@ -137,7 +153,7 @@ class Connection(abc.ABC): except asyncio.CancelledError: pass except Exception: - logging.exception('Unhandled exception in the sending loop') + logging.exception('Unhandled exception in the send loop') self.disconnect() async def _recv_loop(self): @@ -147,11 +163,16 @@ class Connection(abc.ABC): try: while not self._disconnected.is_set(): data = await self._recv() - await self._recv_queue.put((True, data)) + await self._recv_queue.put(data) except asyncio.CancelledError: pass except Exception as e: - await self._recv_queue.put((False, e)) + if isinstance(e, asyncio.IncompleteReadError): + logging.info('The server closed the connection') + else: + logging.exception('Unhandled exception in the receive loop') + + await self._recv_queue.put(None) self.disconnect() @abc.abstractmethod