From 6d6c1917bcf28d7ead91ddf34e179c27bd709ccb Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 10:04:12 +0100 Subject: [PATCH] Implement iterator over message by IDs --- telethon/client/messages.py | 97 ++++++++++++++++++------------------- telethon/requestiter.py | 2 + 2 files changed, 48 insertions(+), 51 deletions(-) diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 1181b14f..9750003d 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -228,7 +228,8 @@ class _MessagesIter(RequestIter): class _IDsIter(RequestIter): - async def _init(self, entity, from_user, ids): + async def _init(self, entity, ids): + # TODO We never actually split IDs in chunks, but maybe we should if not utils.is_list_like(ids): self.ids = [ids] elif not ids: @@ -238,10 +239,52 @@ class _IDsIter(RequestIter): else: self.ids = ids - raise NotImplementedError + if entity: + entity = await self.client.get_input_entity(entity) + + self.total = len(ids) + + from_id = None # By default, no need to validate from_id + if isinstance(entity, (types.InputChannel, types.InputPeerChannel)): + try: + r = await self.client( + functions.channels.GetMessagesRequest(entity, ids)) + except errors.MessageIdsEmptyError: + # All IDs were invalid, use a dummy result + r = types.messages.MessagesNotModified(len(ids)) + else: + r = await self.client(functions.messages.GetMessagesRequest(ids)) + if entity: + from_id = await self.client.get_peer_id(entity) + + if isinstance(r, types.messages.MessagesNotModified): + self.buffer = [None] * len(ids) + return + + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + # Telegram seems to return the messages in the order in which + # we asked them for, so we don't need to check it ourselves, + # unless some messages were invalid in which case Telegram + # may decide to not send them at all. + # + # The passed message IDs may not belong to the desired entity + # since the user can enter arbitrary numbers which can belong to + # arbitrary chats. Validate these unless ``from_id is None``. + result = [] + for message in r.messages: + if isinstance(message, types.MessageEmpty) or ( + from_id and message.chat_id != from_id): + result.append(None) + else: + message._finish_init(self.client, entities, entity) + result.append(message) + + self.buffer = result async def _load_next_chunk(self): - raise NotImplementedError + return [] # no next chunk, all done in init class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): @@ -850,51 +893,3 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): # endregion # endregion - - # region Private methods - - async def _iter_ids(self, entity, ids, total): - """ - Special case for `iter_messages` when it should only fetch some IDs. - """ - if total: - total[0] = len(ids) - - from_id = None # By default, no need to validate from_id - if isinstance(entity, (types.InputChannel, types.InputPeerChannel)): - try: - r = await self( - functions.channels.GetMessagesRequest(entity, ids)) - except errors.MessageIdsEmptyError: - # All IDs were invalid, use a dummy result - r = types.messages.MessagesNotModified(len(ids)) - else: - r = await self(functions.messages.GetMessagesRequest(ids)) - if entity: - from_id = utils.get_peer_id(entity) - - if isinstance(r, types.messages.MessagesNotModified): - for _ in ids: - await yield_(None) - return - - entities = {utils.get_peer_id(x): x - for x in itertools.chain(r.users, r.chats)} - - # Telegram seems to return the messages in the order in which - # we asked them for, so we don't need to check it ourselves, - # unless some messages were invalid in which case Telegram - # may decide to not send them at all. - # - # The passed message IDs may not belong to the desired entity - # since the user can enter arbitrary numbers which can belong to - # arbitrary chats. Validate these unless ``from_id is None``. - for message in r.messages: - if isinstance(message, types.MessageEmpty) or ( - from_id and message.chat_id != from_id): - await yield_(None) - else: - message._finish_init(self, entities, entity) - await yield_(message) - - # endregion diff --git a/telethon/requestiter.py b/telethon/requestiter.py index ba897a2a..98a1cfb6 100644 --- a/telethon/requestiter.py +++ b/telethon/requestiter.py @@ -48,6 +48,8 @@ class RequestIter(abc.ABC): to avoid forgetting or misspelling any of them. This method may ``raise StopAsyncIteration`` if it cannot continue. + + This method may actually fill the initial buffer if it needs to. """ async def __anext__(self):