From c73b8eda26badb5b2d5fcb59a082e37f4f8ce3be Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 11:24:47 +0100 Subject: [PATCH] Simplify filling RequestIter's buffer --- telethon/client/chats.py | 28 +++++++++------------------- telethon/client/dialogs.py | 20 ++++++++------------ telethon/client/messages.py | 28 ++++++++++------------------ telethon/requestiter.py | 16 ++++++++++++---- 4 files changed, 39 insertions(+), 53 deletions(-) diff --git a/telethon/client/chats.py b/telethon/client/chats.py index 7afe4655..83f8ce5e 100644 --- a/telethon/client/chats.py +++ b/telethon/client/chats.py @@ -68,7 +68,6 @@ class _ParticipantsIter(RequestIter): self.total = len(full.full_chat.participants.participants) - result = [] users = {user.id: user for user in full.users} for participant in full.full_chat.participants.participants: user = users[participant.user_id] @@ -77,26 +76,22 @@ class _ParticipantsIter(RequestIter): user = users[participant.user_id] user.participant = participant - result.append(user) + self.buffer.append(user) - self.left = len(result) - self.buffer = result + return True else: - result = [] self.total = 1 if self.limit != 0: user = await self.client.get_entity(entity) if self.filter_entity(user): user.participant = None - result.append(user) + self.buffer.append(user) - self.left = len(result) - self.buffer = result + return True async def _load_next_chunk(self): - result = [] if not self.requests: - return result + return True # Only care about the limit for the first request # (small amount of people, won't be aggressive). @@ -106,7 +101,7 @@ class _ParticipantsIter(RequestIter): # precise with being out of the offset/limit here. self.requests[0].limit = min(self.limit - self.requests[0].offset, 200) if self.requests[0].offset > self.limit: - return result + return True results = await self.client(self.requests) for i in reversed(range(len(self.requests))): @@ -125,9 +120,7 @@ class _ParticipantsIter(RequestIter): self.seen.add(participant.user_id) user = users[participant.user_id] user.participant = participant - result.append(user) - - return result + self.buffer.append(user) class _AdminLogIter(RequestIter): @@ -163,7 +156,6 @@ class _AdminLogIter(RequestIter): ) async def _load_next_chunk(self): - result = [] self.request.limit = min(self.left, 100) r = await self.client(self.request) entities = {utils.get_peer_id(x): x @@ -184,12 +176,10 @@ class _AdminLogIter(RequestIter): ev.action.message._finish_init( self.client, entities, self.entity) - result.append(custom.AdminLogEvent(ev, entities)) + self.buffer.append(custom.AdminLogEvent(ev, entities)) if len(r.events) < self.request.limit: - self.left = len(result) - - return result + return True class ChatMethods(UserMethods): diff --git a/telethon/client/dialogs.py b/telethon/client/dialogs.py index bf3f903f..453dfdec 100644 --- a/telethon/client/dialogs.py +++ b/telethon/client/dialogs.py @@ -29,8 +29,6 @@ class _DialogsIter(RequestIter): self.ignore_migrated = ignore_migrated async def _load_next_chunk(self): - result = [] - self.request.limit = min(self.left, 100) r = await self.client(self.request) @@ -62,34 +60,32 @@ class _DialogsIter(RequestIter): if not self.ignore_migrated or getattr( cd.entity, 'migrated_to', None) is None: - result.append(cd) + self.buffer.append(cd) if len(r.dialogs) < self.request.limit\ or not isinstance(r, types.messages.DialogsSlice): # Less than we requested means we reached the end, or # we didn't get a DialogsSlice which means we got all. - self.left = len(result) - - self.request.offset_date = r.messages[-1].date - self.request.offset_peer =\ - entities[utils.get_peer_id(r.dialogs[-1].peer)] + return True if self.request.offset_id == r.messages[-1].id: # In some very rare cases this will get stuck in an infinite # loop, where the offsets will get reused over and over. If # the new offset is the same as the one before, break already. - self.left = len(result) + return True self.request.offset_id = r.messages[-1].id self.request.exclude_pinned = True - return result + self.request.offset_date = r.messages[-1].date + self.request.offset_peer =\ + entities[utils.get_peer_id(r.dialogs[-1].peer)] class _DraftsIter(RequestIter): async def _init(self, **kwargs): r = await self.client(functions.messages.GetAllDraftsRequest()) - self.buffer = [custom.Draft._from_update(self.client, u) - for u in r.updates] + self.buffer.extend(custom.Draft._from_update(self.client, u) + for u in r.updates) async def _load_next_chunk(self): return [] diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 87d49be6..e87d80a0 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -139,8 +139,6 @@ class _MessagesIter(RequestIter): self.last_id = 0 if self.reverse else float('inf') async def _load_next_chunk(self): - result = [] - self.request.limit = min(self.left, self.batch_size) if self.reverse and self.request.limit != self.batch_size: # Remember that we need -limit when going in reverse @@ -159,8 +157,7 @@ class _MessagesIter(RequestIter): continue if not self._message_in_range(message): - self.left = len(result) - break + return True # There has been reports that on bad connections this method # was returning duplicated IDs sometimes. Using ``last_id`` @@ -168,15 +165,15 @@ class _MessagesIter(RequestIter): # IDs are returned in descending order (or asc if reverse). self.last_id = message.id message._finish_init(self.client, entities, self.entity) - result.append(message) + self.buffer.append(message) if len(r.messages) < self.request.limit: - self.left = len(result) + return True # Get the last message that's not empty (in some rare cases # it can happen that the last message is :tl:`MessageEmpty`) - if result: - self._update_offset(result[-1]) + if self.buffer: + self._update_offset(self.buffer[-1]) else: # There are some cases where all the messages we get start # being empty. This can happen on migrated mega-groups if @@ -184,9 +181,7 @@ class _MessagesIter(RequestIter): # acts incredibly weird sometimes. Messages are returned but # only "empty", not their contents. If this is the case we # should just give up since there won't be any new Message. - self.left = len(result) - - return result + return True def _message_in_range(self, message): """ @@ -258,7 +253,7 @@ class _IDsIter(RequestIter): from_id = await self.client.get_peer_id(entity) if isinstance(r, types.messages.MessagesNotModified): - self.buffer = [None] * len(ids) + self.buffer.extend(None for _ in ids) return entities = {utils.get_peer_id(x): x @@ -272,19 +267,16 @@ class _IDsIter(RequestIter): # The passed message IDs may not belong to the desired entity # since the user can enter arbitrary numbers which can belong to # arbitrary chats. Validate these unless ``from_id is None``. - result = [] for message in r.messages: if isinstance(message, types.MessageEmpty) or ( from_id and message.chat_id != from_id): - result.append(None) + self.buffer.append(None) else: message._finish_init(self.client, entities, entity) - result.append(message) - - self.buffer = result + self.buffer.append(message) async def _load_next_chunk(self): - return [] # no next chunk, all done in init + return True # no next chunk, all done in init class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): diff --git a/telethon/requestiter.py b/telethon/requestiter.py index ca9c8d72..111d4509 100644 --- a/telethon/requestiter.py +++ b/telethon/requestiter.py @@ -55,7 +55,8 @@ class RequestIter(abc.ABC): """ async def __anext__(self): - if self.buffer is (): + if self.buffer is None: + self.buffer = [] await self._init(**self.kwargs) if self.left <= 0: # <= 0 because subclasses may change it @@ -71,7 +72,9 @@ class RequestIter(abc.ABC): self.last_load = time.time() self.index = 0 - self.buffer = await self._load_next_chunk() + self.buffer = [] + if await self._load_next_chunk(): + self.left = len(self.buffer) if not self.buffer: raise StopAsyncIteration @@ -82,7 +85,7 @@ class RequestIter(abc.ABC): return result def __aiter__(self): - self.buffer = () + self.buffer = None self.index = 0 self.last_load = 0 self.left = self.limit @@ -113,7 +116,12 @@ class RequestIter(abc.ABC): async def _load_next_chunk(self): """ Called when the next chunk is necessary. - It should *always* return a `list`. + + It should extend the `buffer` with new items. + + It should return ``True`` if it's the last chunk, + after which moment the method won't be called again + during the same iteration. """ raise NotImplementedError