diff --git a/telethon/tl/custom/conversation.py b/telethon/tl/custom/conversation.py index 2f91cda1..f6b69cd9 100644 --- a/telethon/tl/custom/conversation.py +++ b/telethon/tl/custom/conversation.py @@ -138,7 +138,7 @@ class Conversation(ChatGetter): lambda x, y: x.reply_to_msg_id == y ) - async def _get_message( + def _get_message( self, target_message, indices, pending, timeout, condition): """ Gets the next desired message under the desired condition. @@ -192,11 +192,12 @@ class Conversation(ChatGetter): return future # Otherwise the next incoming response will be the one to use + # + # Note how we fill "pending" before giving control back to the + # event loop through "await". We want to register it as soon as + # possible, since any other task switch may arrive with the result. pending[target_id] = future - try: - return await self._get_result(future, start_time, timeout) - finally: - pending.pop(target_id, None) + return self._get_result(future, start_time, timeout, pending, target_id) async def get_edit(self, message=None, *, timeout=None): """ @@ -222,12 +223,9 @@ class Conversation(ChatGetter): return earliest_edit # Otherwise the next incoming response will be the one to use - future = asyncio.Future(loop=self._client.loop) + future = self._client.loop.create_future() self._pending_edits[target_id] = future - try: - return await self._get_result(future, start_time, timeout) - finally: - self._pending_edits.pop(target_id, None) + return await self._get_result(future, start_time, timeout, self._pending_edits, target_id) async def wait_read(self, message=None, *, timeout=None): """ @@ -246,10 +244,7 @@ class Conversation(ChatGetter): return self._pending_reads[target_id] = future - try: - return await self._get_result(future, start_time, timeout) - finally: - self._pending_reads.pop(target_id, None) + return await self._get_result(future, start_time, timeout, self._pending_reads, target_id) async def wait_event(self, event, *, timeout=None): """ @@ -284,20 +279,9 @@ class Conversation(ChatGetter): counter = Conversation._custom_counter Conversation._custom_counter += 1 - future = asyncio.Future(loop=self._client.loop) - - # We need the `async def` here because we want to block on the future - # from `_get_result` by using `await` on it. If we returned the future - # immediately we would `del` from `_custom` too early. - - async def result(): - try: - return await self._get_result(future, start_time, timeout) - finally: - del self._custom[counter] - + future = self._client.loop.create_future() self._custom[counter] = (event, future) - return await result() + return await self._get_result(future, start_time, timeout, self._custom, counter) async def _check_custom(self, built): for i, (ev, fut) in self._custom.items(): @@ -317,32 +301,23 @@ class Conversation(ChatGetter): self._incoming.append(response) - found = [] - for msg_id in self._pending_responses: - found.append(msg_id) + # Note: we don't remove from pending here, that's done on get result + for msg_id, future in self._pending_responses.items(): self._response_indices[msg_id] = len(self._incoming) + future.set_result(response) - for msg_id in found: - self._pending_responses.pop(msg_id).set_result(response) - - found.clear() - for msg_id in self._pending_replies: + for msg_id, future in self._pending_replies.items(): if msg_id == response.reply_to_msg_id: - found.append(msg_id) self._reply_indices[msg_id] = len(self._incoming) - - for msg_id in found: - self._pending_replies.pop(msg_id).set_result(response) + future.set_result(response) def _on_edit(self, message): message = message.message if message.chat_id != self.chat_id or message.out: return - found = [] - for msg_id, pending in self._pending_edits.items(): + for msg_id, future in self._pending_edits.items(): if msg_id < message.id: - found.append(msg_id) edit_ts = message.edit_date.timestamp() # We compare <= because edit_ts resolution is always to @@ -353,8 +328,7 @@ class Conversation(ChatGetter): else: self._edit_dates[msg_id] = message.edit_date.timestamp() - for msg_id in found: - self._pending_edits.pop(msg_id).set_result(message) + future.set_result(message) def _on_read(self, event): if event.chat_id != self.chat_id or event.inbox: @@ -379,7 +353,7 @@ class Conversation(ChatGetter): else: raise ValueError('No message was sent previously') - def _get_result(self, future, start_time, timeout): + async def _get_result(self, future, start_time, timeout, pending, target_id): due = self._total_due if timeout is None: timeout = self._timeout @@ -387,11 +361,14 @@ class Conversation(ChatGetter): if timeout is not None: due = min(due, start_time + timeout) - return asyncio.wait_for( - future, - timeout=None if due == float('inf') else due - time.time(), - loop=self._client.loop - ) + try: + return await asyncio.wait_for( + future, + timeout=None if due == float('inf') else due - time.time(), + loop=self._client.loop + ) + finally: + del pending[target_id] def _cancel_all(self, exception=None): for pending in itertools.chain(