diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 3fc2bcb9..e0fea3d2 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -1,5 +1,6 @@ import abc import asyncio +import collections import logging import platform import time @@ -322,8 +323,9 @@ class TelegramBaseClient(abc.ABC): # Some further state for subclasses self._event_builders = [] - self._conversations = {} - self._ids_in_conversations = {} # chat_id: count + + # {chat_id: {Conversation}} + self._conversations = collections.defaultdict(set) # Default parse mode self._parse_mode = markdown diff --git a/telethon/client/updates.py b/telethon/client/updates.py index c6df1546..69333a8a 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -382,8 +382,8 @@ class UpdateMethods(UserMethods): await self._get_difference(update, channel_id, pts_date) built = EventBuilderDict(self, update) - if self._conversations: - for conv in self._conversations.values(): + for conv_set in self._conversations.values(): + for conv in conv_set: ev = built[events.NewMessage] if ev: conv._on_new_message(ev) diff --git a/telethon/tl/custom/conversation.py b/telethon/tl/custom/conversation.py index f4fa234d..343bf6ec 100644 --- a/telethon/tl/custom/conversation.py +++ b/telethon/tl/custom/conversation.py @@ -394,12 +394,11 @@ class Conversation(ChatGetter): # Make sure we're the only conversation in this chat if it's exclusive chat_id = utils.get_peer_id(self._chat_peer) - count = self._client._ids_in_conversations.get(chat_id, 0) - if self._exclusive and count: + conv_set = self._client._conversations[chat_id] + if self._exclusive and conv_set: raise errors.AlreadyInConversationError() - self._client._ids_in_conversations[chat_id] = count + 1 - self._client._conversations[self._id] = self + conv_set.add(self) self._last_outgoing = 0 self._last_incoming = 0 @@ -426,14 +425,24 @@ class Conversation(ChatGetter): """ self._cancel_all() + async def cancel_all(self): + """ + Calls `cancel` on *all* conversations in this chat. + + Note that you should ``await`` this method, since it's meant to be + used outside of a context manager, and it needs to resolve the chat. + """ + chat_id = await self._client.get_peer_id(self._input_chat) + for conv in self._client._conversations[chat_id]: + conv.cancel() + async def __aexit__(self, exc_type, exc_val, exc_tb): chat_id = utils.get_peer_id(self._chat_peer) - if self._client._ids_in_conversations[chat_id] == 1: - del self._client._ids_in_conversations[chat_id] - else: - self._client._ids_in_conversations[chat_id] -= 1 + conv_set = self._client._conversations[chat_id] + conv_set.discard(self) + if not conv_set: + del self._client._conversations[chat_id] - del self._client._conversations[self._id] self._cancel_all() __enter__ = helpers._sync_enter