mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-06-18 19:16:43 +00:00
Cancel tasks on reconnect instead of awaiting them
This prevents us from locking forever on any task that doesn't rely on cancellation tokens, in this case, Connection.recv()'s _recv_queue.get() would never complete after the server closed the connection. Additionally, working with cancellation tokens in asyncio is somewhat annoying to do. Last but not least removing the Connection._disconnected future avoids the need to use its state (if an exception was set it should be retrieved) to prevent asyncio from complaining, which it was before.
This commit is contained in:
parent
f2e77f4030
commit
7dece209a0
@ -38,7 +38,7 @@ class MessagePacker:
|
|||||||
self._deque.extend(states)
|
self._deque.extend(states)
|
||||||
self._ready.set()
|
self._ready.set()
|
||||||
|
|
||||||
async def get(self, cancellation):
|
async def get(self):
|
||||||
"""
|
"""
|
||||||
Returns (batch, data) if one or more items could be retrieved.
|
Returns (batch, data) if one or more items could be retrieved.
|
||||||
|
|
||||||
@ -47,19 +47,7 @@ class MessagePacker:
|
|||||||
"""
|
"""
|
||||||
if not self._deque:
|
if not self._deque:
|
||||||
self._ready.clear()
|
self._ready.clear()
|
||||||
ready = self._loop.create_task(self._ready.wait())
|
await self._ready.wait()
|
||||||
try:
|
|
||||||
done, pending = await asyncio.wait(
|
|
||||||
[ready, cancellation],
|
|
||||||
return_when=asyncio.FIRST_COMPLETED,
|
|
||||||
loop=self._loop
|
|
||||||
)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
done = [cancellation]
|
|
||||||
|
|
||||||
if cancellation in done:
|
|
||||||
ready.cancel()
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
batch = []
|
batch = []
|
||||||
|
@ -28,8 +28,7 @@ class Connection(abc.ABC):
|
|||||||
self._proxy = proxy
|
self._proxy = proxy
|
||||||
self._reader = None
|
self._reader = None
|
||||||
self._writer = None
|
self._writer = None
|
||||||
self._disconnected = self._loop.create_future()
|
self._connected = False
|
||||||
self._disconnected.set_result(None)
|
|
||||||
self._send_task = None
|
self._send_task = None
|
||||||
self._recv_task = None
|
self._recv_task = None
|
||||||
self._send_queue = asyncio.Queue(1)
|
self._send_queue = asyncio.Queue(1)
|
||||||
@ -77,7 +76,7 @@ class Connection(abc.ABC):
|
|||||||
self._reader, self._writer = \
|
self._reader, self._writer = \
|
||||||
await asyncio.open_connection(sock=s, loop=self._loop)
|
await asyncio.open_connection(sock=s, loop=self._loop)
|
||||||
|
|
||||||
self._disconnected = self._loop.create_future()
|
self._connected = True
|
||||||
self._send_task = self._loop.create_task(self._send_loop())
|
self._send_task = self._loop.create_task(self._send_loop())
|
||||||
self._recv_task = self._loop.create_task(self._recv_loop())
|
self._recv_task = self._loop.create_task(self._recv_loop())
|
||||||
|
|
||||||
@ -89,11 +88,7 @@ class Connection(abc.ABC):
|
|||||||
self._disconnect(error=None)
|
self._disconnect(error=None)
|
||||||
|
|
||||||
def _disconnect(self, error):
|
def _disconnect(self, error):
|
||||||
if not self._disconnected.done():
|
self._connected = False
|
||||||
if error:
|
|
||||||
self._disconnected.set_exception(error)
|
|
||||||
else:
|
|
||||||
self._disconnected.set_result(None)
|
|
||||||
|
|
||||||
while not self._send_queue.empty():
|
while not self._send_queue.empty():
|
||||||
self._send_queue.get_nowait()
|
self._send_queue.get_nowait()
|
||||||
@ -110,10 +105,6 @@ class Connection(abc.ABC):
|
|||||||
if self._writer:
|
if self._writer:
|
||||||
self._writer.close()
|
self._writer.close()
|
||||||
|
|
||||||
@property
|
|
||||||
def disconnected(self):
|
|
||||||
return asyncio.shield(self._disconnected, loop=self._loop)
|
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
"""
|
"""
|
||||||
Creates a clone of the connection.
|
Creates a clone of the connection.
|
||||||
@ -126,7 +117,7 @@ class Connection(abc.ABC):
|
|||||||
|
|
||||||
This method returns a coroutine.
|
This method returns a coroutine.
|
||||||
"""
|
"""
|
||||||
if self._disconnected.done():
|
if not self._connected:
|
||||||
raise ConnectionError('Not connected')
|
raise ConnectionError('Not connected')
|
||||||
|
|
||||||
return self._send_queue.put(data)
|
return self._send_queue.put(data)
|
||||||
@ -137,7 +128,7 @@ class Connection(abc.ABC):
|
|||||||
|
|
||||||
This method returns a coroutine.
|
This method returns a coroutine.
|
||||||
"""
|
"""
|
||||||
if self._disconnected.done():
|
if not self._connected:
|
||||||
raise ConnectionError('Not connected')
|
raise ConnectionError('Not connected')
|
||||||
|
|
||||||
result = await self._recv_queue.get()
|
result = await self._recv_queue.get()
|
||||||
@ -151,7 +142,7 @@ class Connection(abc.ABC):
|
|||||||
This loop is constantly popping items off the queue to send them.
|
This loop is constantly popping items off the queue to send them.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
while not self._disconnected.done():
|
while self._connected:
|
||||||
self._send(await self._send_queue.get())
|
self._send(await self._send_queue.get())
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@ -166,7 +157,7 @@ class Connection(abc.ABC):
|
|||||||
This loop is constantly putting items on the queue as they're read.
|
This loop is constantly putting items on the queue as they're read.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
while not self._disconnected.done():
|
while self._connected:
|
||||||
data = await self._recv()
|
data = await self._recv()
|
||||||
await self._recv_queue.put(data)
|
await self._recv_queue.put(data)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
@ -279,11 +279,11 @@ class MTProtoSender:
|
|||||||
__log__.debug('Closing current connection...')
|
__log__.debug('Closing current connection...')
|
||||||
self._connection.disconnect()
|
self._connection.disconnect()
|
||||||
|
|
||||||
__log__.debug('Awaiting for the send loop before reconnecting...')
|
__log__.debug('Cancelling the send loop...')
|
||||||
await self._send_loop_handle
|
self._send_loop_handle.cancel()
|
||||||
|
|
||||||
__log__.debug('Awaiting for the receive loop before reconnecting...')
|
__log__.debug('Cancelling the receive loop...')
|
||||||
await self._recv_loop_handle
|
self._recv_loop_handle.cancel()
|
||||||
|
|
||||||
self._reconnecting = False
|
self._reconnecting = False
|
||||||
|
|
||||||
@ -334,8 +334,7 @@ class MTProtoSender:
|
|||||||
# This means that while it's not empty we can wait for
|
# This means that while it's not empty we can wait for
|
||||||
# more messages to be added to the send queue.
|
# more messages to be added to the send queue.
|
||||||
try:
|
try:
|
||||||
batch, data = await self._send_queue.get(
|
batch, data = await self._send_queue.get()
|
||||||
self._connection.disconnected)
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user