From 36eb1b1009eaab4836d050cc0e8499b004e93401 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Tue, 26 Feb 2019 20:26:40 +0100 Subject: [PATCH 01/15] Create a new RequestIter ABC to deal with iter methods This should make it easier to maintain these methods, increase reusability, and get rid of the async_generator dependency. In the future, people could use this to more easily deal with raw API themselves. --- telethon/client/messages.py | 162 +++++++++++++++++++++++++++++++++++- telethon/requestiter.py | 99 ++++++++++++++++++++++ 2 files changed, 257 insertions(+), 4 deletions(-) create mode 100644 telethon/requestiter.py diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 2ad8d4b2..2f928c69 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -2,13 +2,152 @@ import asyncio import itertools import time -from async_generator import async_generator, yield_ - from .messageparse import MessageParseMethods from .uploads import UploadMethods from .buttons import ButtonMethods from .. import helpers, utils, errors from ..tl import types, functions +from ..requestiter import RequestIter + + +class _GetHistoryIter(RequestIter): + async def _init(self, entity, offset_id, min_id, max_id, from_user, batch_size, offset_date, add_offset): + self.entity = await self.client.get_input_entity(entity) + + # Telegram doesn't like min_id/max_id. If these IDs are low enough + # (starting from last_id - 100), the request will return nothing. + # + # We can emulate their behaviour locally by setting offset = max_id + # and simply stopping once we hit a message with ID <= min_id. + if self.reverse: + offset_id = max(offset_id, min_id) + if offset_id and max_id: + if max_id - offset_id <= 1: + raise StopAsyncIteration + + if not max_id: + max_id = float('inf') + else: + offset_id = max(offset_id, max_id) + if offset_id and min_id: + if offset_id - min_id <= 1: + raise StopAsyncIteration + + if self.reverse: + if offset_id: + offset_id += 1 + else: + offset_id = 1 + + if from_user: + from_user = await self.client.get_input_entity(from_user) + if not isinstance(from_user, ( + types.InputPeerUser, types.InputPeerSelf)): + from_user = None # Ignore from_user unless it's a user + + self.from_id = (await self.client.get_peer_id(from_user)) if from_user else None + + self.request = functions.messages.GetHistoryRequest( + peer=entity, + limit=1, + offset_date=offset_date, + offset_id=offset_id, + min_id=0, + max_id=0, + add_offset=add_offset, + hash=0 + ) + + if self.limit == 0: + # No messages, but we still need to know the total message count + result = await self.client(self.request) + if isinstance(result, types.messages.MessagesNotModified): + self.total = result.count + else: + self.total = getattr(result, 'count', len(result.messages)) + raise StopAsyncIteration + + # When going in reverse we need an offset of `-limit`, but we + # also want to respect what the user passed, so add them together. + if self.reverse: + self.request.add_offset -= batch_size + + if self.wait_time is None: + self.wait_time = 1 if self.limit > 3000 else 0 + + # Telegram has a hard limit of 100. + # We don't need to fetch 100 if the limit is less. + self.batch_size = min(max(batch_size, 1), min(100, self.limit)) + self.add_offset = add_offset + self.max_id = max_id + self.min_id = min_id + self.last_id = 0 if self.reverse else float('inf') + + async def _load_next_chunk(self): + result = [] + + self.request.limit = min(self.left, self.batch_size) + if self.reverse and self.request.limit != self.batch_size: + # Remember that we need -limit when going in reverse + self.request.add_offset = self.add_offset - self.request.limit + + r = await self.client(self.request) + self.total = getattr(r, 'count', len(r.messages)) + + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + messages = reversed(r.messages) if self.reverse else r.messages + for message in messages: + if (isinstance(message, types.MessageEmpty) + or self.from_id and message.from_id != self.from_id): + continue + + # TODO We used to yield and return here (stopping the iterator) + # How should we go around that here? + if self.reverse: + if message.id <= self.last_id or message.id >= self.max_id: + break + else: + if message.id >= self.last_id or message.id <= self.min_id: + break + + # There has been reports that on bad connections this method + # was returning duplicated IDs sometimes. Using ``last_id`` + # is an attempt to avoid these duplicates, since the message + # IDs are returned in descending order (or asc if reverse). + self.last_id = message.id + message._finish_init(self.client, entities, self.entity) + result.append(message) + + if len(r.messages) < self.request.limit: + return result + + # Find the first message that's not empty (in some rare cases + # it can happen that the last message is :tl:`MessageEmpty`) + last_message = None + messages = r.messages if self.reverse else reversed(r.messages) + for m in messages: + if not isinstance(m, types.MessageEmpty): + last_message = m + break + + # TODO If it's None, we used to break (ending the iterator) + # Similar case as the return above. + if last_message is not None: + # There are some cases where all the messages we get start + # being empty. This can happen on migrated mega-groups if + # the history was cleared, and we're using search. Telegram + # acts incredibly weird sometimes. Messages are returned but + # only "empty", not their contents. If this is the case we + # should just give up since there won't be any new Message. + self.request.offset_id = last_message.id + self.request.offset_date = last_message.date + if self.reverse: + # We want to skip the one we already have + self.request.offset_id += 1 + + return result class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): @@ -17,7 +156,6 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): # region Message retrieval - @async_generator async def iter_messages( self, entity, limit=None, *, offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0, search=None, filter=None, @@ -133,6 +271,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): an higher limit, so you're free to set the ``batch_size`` that you think may be good. """ + # TODO Handle global search + # TODO Handle search + # TODO Handle yield IDs + return _GetHistoryIter( + self, + limit=limit, + wait_time=wait_time, + entity=entity, + reverse=reverse, + offset_id=offset_id, + min_id=min_id, + max_id=max_id, + from_user=from_user, + batch_size=batch_size, + offset_date=offset_date, + add_offset=add_offset + ) # Note that entity being ``None`` is intended to get messages by # ID under no specific chat, and also to request a global search. if entity: @@ -802,7 +957,6 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): # region Private methods - @async_generator async def _iter_ids(self, entity, ids, total): """ Special case for `iter_messages` when it should only fetch some IDs. diff --git a/telethon/requestiter.py b/telethon/requestiter.py new file mode 100644 index 00000000..d8ccff53 --- /dev/null +++ b/telethon/requestiter.py @@ -0,0 +1,99 @@ +import abc +import asyncio +import time + + +class RequestIter(abc.ABC): + """ + Helper class to deal with requests that need offsets to iterate. + + It has some facilities, such as automatically sleeping a desired + amount of time between requests if needed (but not more). + + Can be used synchronously if the event loop is not running and + as an asynchronous iterator otherwise. + + `limit` is the total amount of items that the iterator should return. + This is handled on this base class, and will be always ``>= 0``. + + `left` will be reset every time the iterator is used and will indicate + the amount of items that should be emitted left, so that subclasses can + be more efficient and fetch only as many items as they need. + + Iterators may be used with ``reversed``, and their `reverse` flag will + be set to ``True`` if that's the case. Note that if this flag is set, + `buffer` should be filled in reverse too. + """ + def __init__(self, client, limit, *, reverse=False, wait_time=None, **kwargs): + self.client = client + self.reverse = reverse + self.wait_time = wait_time + self.kwargs = kwargs + self.limit = max(float('inf') if limit is None else limit, 0) + self.left = None + self.buffer = None + self.index = None + self.total = None + self.last_load = None + + async def _init(self, **kwargs): + """ + Called when asynchronous initialization is necessary. All keyword + arguments passed to `__init__` will be forwarded here, and it's + preferable to use named arguments in the subclasses without defaults + to avoid forgetting or misspelling any of them. + + This method may ``raise StopAsyncIteration`` if it cannot continue. + """ + + async def __anext__(self): + if self.buffer is (): + await self._init(**self.kwargs) + + if self.index == len(self.buffer): + # asyncio will handle times <= 0 to sleep 0 seconds + if self.wait_time: + await asyncio.sleep( + self.wait_time - (time.time() - self.last_load), + loop=self.client.loop + ) + self.last_load = time.time() + + self.index = 0 + self.buffer = await self._load_next_chunk() + + if not self.buffer: + raise StopAsyncIteration + + result = self.buffer[self.index] + self.left -= 1 + self.index += 1 + return result + + def __aiter__(self): + self.buffer = () + self.index = 0 + self.last_load = 0 + self.left = self.limit + return self + + def __iter__(self): + if self.client.loop.is_running(): + raise RuntimeError( + 'You must use "async for" if the event loop ' + 'is running (i.e. you are inside an "async def")' + ) + + raise NotImplementedError('lol!') + + @abc.abstractmethod + async def _load_next_chunk(self): + """ + Called when the next chunk is necessary. + It should *always* return a `list`. + """ + raise NotImplementedError + + def __reversed__(self): + self.reverse = not self.reverse + return self # __aiter__ will be called after, too From 19f38d6733d055968f6b1d48efbae6bc94180a2d Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Tue, 26 Feb 2019 21:04:46 +0100 Subject: [PATCH 02/15] Implement iter_messages with search --- telethon/client/messages.py | 204 +++++++++++++++++++++++++++++++++--- 1 file changed, 189 insertions(+), 15 deletions(-) diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 2f928c69..f4cb9cb7 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -150,6 +150,163 @@ class _GetHistoryIter(RequestIter): return result +class _SearchMessagesIter(RequestIter): + async def _init(self, entity, offset_id, min_id, max_id, from_user, batch_size, offset_date, add_offset, filter, search): + self.entity = await self.client.get_input_entity(entity) + + # Telegram doesn't like min_id/max_id. If these IDs are low enough + # (starting from last_id - 100), the request will return nothing. + # + # We can emulate their behaviour locally by setting offset = max_id + # and simply stopping once we hit a message with ID <= min_id. + if self.reverse: + offset_id = max(offset_id, min_id) + if offset_id and max_id: + if max_id - offset_id <= 1: + raise StopAsyncIteration + + if not max_id: + max_id = float('inf') + else: + offset_id = max(offset_id, max_id) + if offset_id and min_id: + if offset_id - min_id <= 1: + raise StopAsyncIteration + + if self.reverse: + if offset_id: + offset_id += 1 + else: + offset_id = 1 + + if from_user: + from_user = await self.client.get_input_entity(from_user) + if not isinstance(from_user, ( + types.InputPeerUser, types.InputPeerSelf)): + from_user = None # Ignore from_user unless it's a user + + self.from_id = (await self.client.get_peer_id(from_user)) if from_user else None + + if filter is None: + filter = types.InputMessagesFilterEmpty() + + # Telegram completely ignores `from_id` in private chats + if isinstance(entity, (types.InputPeerUser, types.InputPeerSelf)): + # Don't bother sending `from_user` (it's ignored anyway), + # but keep `from_id` defined above to check it locally. + from_user = None + else: + # Do send `from_user` to do the filtering server-side, + # and set `from_id` to None to avoid checking it locally. + self.from_id = None + + self.request = functions.messages.SearchRequest( + peer=entity, + q=search or '', + filter=filter() if isinstance(filter, type) else filter, + min_date=None, + max_date=offset_date, + offset_id=offset_id, + add_offset=add_offset, + limit=0, # Search actually returns 0 items if we ask it to + max_id=0, + min_id=0, + hash=0, + from_id=from_user + ) + + if self.limit == 0: + # No messages, but we still need to know the total message count + result = await self.client(self.request) + if isinstance(result, types.messages.MessagesNotModified): + self.total = result.count + else: + self.total = getattr(result, 'count', len(result.messages)) + raise StopAsyncIteration + + # When going in reverse we need an offset of `-limit`, but we + # also want to respect what the user passed, so add them together. + if self.reverse: + self.request.add_offset -= batch_size + + if self.wait_time is None: + self.wait_time = 1 if self.limit > 3000 else 0 + + # Telegram has a hard limit of 100. + # We don't need to fetch 100 if the limit is less. + self.batch_size = min(max(batch_size, 1), min(100, self.limit)) + self.add_offset = add_offset + self.max_id = max_id + self.min_id = min_id + self.last_id = 0 if self.reverse else float('inf') + + async def _load_next_chunk(self): + result = [] + + self.request.limit = min(self.left, self.batch_size) + if self.reverse and self.request.limit != self.batch_size: + # Remember that we need -limit when going in reverse + self.request.add_offset = self.add_offset - self.request.limit + + r = await self.client(self.request) + self.total = getattr(r, 'count', len(r.messages)) + + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + messages = reversed(r.messages) if self.reverse else r.messages + for message in messages: + if (isinstance(message, types.MessageEmpty) + or self.from_id and message.from_id != self.from_id): + continue + + # TODO We used to yield and return here (stopping the iterator) + # How should we go around that here? + if self.reverse: + if message.id <= self.last_id or message.id >= self.max_id: + break + else: + if message.id >= self.last_id or message.id <= self.min_id: + break + + # There has been reports that on bad connections this method + # was returning duplicated IDs sometimes. Using ``last_id`` + # is an attempt to avoid these duplicates, since the message + # IDs are returned in descending order (or asc if reverse). + self.last_id = message.id + message._finish_init(self.client, entities, self.entity) + result.append(message) + + if len(r.messages) < self.request.limit: + return result + + # Find the first message that's not empty (in some rare cases + # it can happen that the last message is :tl:`MessageEmpty`) + last_message = None + messages = r.messages if self.reverse else reversed(r.messages) + for m in messages: + if not isinstance(m, types.MessageEmpty): + last_message = m + break + + # TODO If it's None, we used to break (ending the iterator) + # Similar case as the return above. + if last_message is not None: + # There are some cases where all the messages we get start + # being empty. This can happen on migrated mega-groups if + # the history was cleared, and we're using search. Telegram + # acts incredibly weird sometimes. Messages are returned but + # only "empty", not their contents. If this is the case we + # should just give up since there won't be any new Message. + self.request.offset_id = last_message.id + self.request.max_date = last_message.date # not offset_date + if self.reverse: + # We want to skip the one we already have + self.request.offset_id += 1 + + return result + + class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): # region Public methods @@ -272,22 +429,39 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): you think may be good. """ # TODO Handle global search - # TODO Handle search # TODO Handle yield IDs - return _GetHistoryIter( - self, - limit=limit, - wait_time=wait_time, - entity=entity, - reverse=reverse, - offset_id=offset_id, - min_id=min_id, - max_id=max_id, - from_user=from_user, - batch_size=batch_size, - offset_date=offset_date, - add_offset=add_offset - ) + # TODO Reuse code between search, global, get history + if search is not None or filter or from_user: + return _SearchMessagesIter( + self, + limit, + entity=entity, + offset_id=offset_id, + min_id=min_id, + max_id=max_id, + from_user=from_user, + batch_size=batch_size, + offset_date=offset_date, + add_offset=add_offset, + filter=filter, + search=search + ) + else: + return _GetHistoryIter( + self, + limit, + wait_time=wait_time, + entity=entity, + reverse=reverse, + offset_id=offset_id, + min_id=min_id, + max_id=max_id, + from_user=from_user, + batch_size=batch_size, + offset_date=offset_date, + add_offset=add_offset + ) + # Note that entity being ``None`` is intended to get messages by # ID under no specific chat, and also to request a global search. if entity: From e2f44ddbeaf38aebeb6abd0049c2c2fe10412a9c Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 09:31:15 +0100 Subject: [PATCH 03/15] Make iter_messages use a common message iterator --- telethon/client/messages.py | 584 ++++++++++-------------------------- telethon/requestiter.py | 9 +- 2 files changed, 159 insertions(+), 434 deletions(-) diff --git a/telethon/client/messages.py b/telethon/client/messages.py index f4cb9cb7..e8d27961 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -1,6 +1,4 @@ -import asyncio import itertools -import time from .messageparse import MessageParseMethods from .uploads import UploadMethods @@ -10,9 +8,24 @@ from ..tl import types, functions from ..requestiter import RequestIter -class _GetHistoryIter(RequestIter): - async def _init(self, entity, offset_id, min_id, max_id, from_user, batch_size, offset_date, add_offset): - self.entity = await self.client.get_input_entity(entity) +# TODO Maybe RequestIter could rather have the update offset here? +# Maybe init should return the request to be used and it be +# called automatically? And another method to just process it. +class _MessagesIter(RequestIter): + """ + Common factor for all requests that need to iterate over messages. + """ + async def _init( + self, entity, offset_id, min_id, max_id, from_user, + batch_size, offset_date, add_offset, filter, search + ): + # Note that entity being ``None`` will perform a global search. + if entity: + self.entity = await self.client.get_input_entity(entity) + else: + self.entity = None + if self.reverse: + raise ValueError('Cannot reverse global search') # Telegram doesn't like min_id/max_id. If these IDs are low enough # (starting from last_id - 100), the request will return nothing. @@ -45,175 +58,59 @@ class _GetHistoryIter(RequestIter): types.InputPeerUser, types.InputPeerSelf)): from_user = None # Ignore from_user unless it's a user - self.from_id = (await self.client.get_peer_id(from_user)) if from_user else None - - self.request = functions.messages.GetHistoryRequest( - peer=entity, - limit=1, - offset_date=offset_date, - offset_id=offset_id, - min_id=0, - max_id=0, - add_offset=add_offset, - hash=0 - ) - - if self.limit == 0: - # No messages, but we still need to know the total message count - result = await self.client(self.request) - if isinstance(result, types.messages.MessagesNotModified): - self.total = result.count - else: - self.total = getattr(result, 'count', len(result.messages)) - raise StopAsyncIteration - - # When going in reverse we need an offset of `-limit`, but we - # also want to respect what the user passed, so add them together. - if self.reverse: - self.request.add_offset -= batch_size - - if self.wait_time is None: - self.wait_time = 1 if self.limit > 3000 else 0 - - # Telegram has a hard limit of 100. - # We don't need to fetch 100 if the limit is less. - self.batch_size = min(max(batch_size, 1), min(100, self.limit)) - self.add_offset = add_offset - self.max_id = max_id - self.min_id = min_id - self.last_id = 0 if self.reverse else float('inf') - - async def _load_next_chunk(self): - result = [] - - self.request.limit = min(self.left, self.batch_size) - if self.reverse and self.request.limit != self.batch_size: - # Remember that we need -limit when going in reverse - self.request.add_offset = self.add_offset - self.request.limit - - r = await self.client(self.request) - self.total = getattr(r, 'count', len(r.messages)) - - entities = {utils.get_peer_id(x): x - for x in itertools.chain(r.users, r.chats)} - - messages = reversed(r.messages) if self.reverse else r.messages - for message in messages: - if (isinstance(message, types.MessageEmpty) - or self.from_id and message.from_id != self.from_id): - continue - - # TODO We used to yield and return here (stopping the iterator) - # How should we go around that here? - if self.reverse: - if message.id <= self.last_id or message.id >= self.max_id: - break - else: - if message.id >= self.last_id or message.id <= self.min_id: - break - - # There has been reports that on bad connections this method - # was returning duplicated IDs sometimes. Using ``last_id`` - # is an attempt to avoid these duplicates, since the message - # IDs are returned in descending order (or asc if reverse). - self.last_id = message.id - message._finish_init(self.client, entities, self.entity) - result.append(message) - - if len(r.messages) < self.request.limit: - return result - - # Find the first message that's not empty (in some rare cases - # it can happen that the last message is :tl:`MessageEmpty`) - last_message = None - messages = r.messages if self.reverse else reversed(r.messages) - for m in messages: - if not isinstance(m, types.MessageEmpty): - last_message = m - break - - # TODO If it's None, we used to break (ending the iterator) - # Similar case as the return above. - if last_message is not None: - # There are some cases where all the messages we get start - # being empty. This can happen on migrated mega-groups if - # the history was cleared, and we're using search. Telegram - # acts incredibly weird sometimes. Messages are returned but - # only "empty", not their contents. If this is the case we - # should just give up since there won't be any new Message. - self.request.offset_id = last_message.id - self.request.offset_date = last_message.date - if self.reverse: - # We want to skip the one we already have - self.request.offset_id += 1 - - return result - - -class _SearchMessagesIter(RequestIter): - async def _init(self, entity, offset_id, min_id, max_id, from_user, batch_size, offset_date, add_offset, filter, search): - self.entity = await self.client.get_input_entity(entity) - - # Telegram doesn't like min_id/max_id. If these IDs are low enough - # (starting from last_id - 100), the request will return nothing. - # - # We can emulate their behaviour locally by setting offset = max_id - # and simply stopping once we hit a message with ID <= min_id. - if self.reverse: - offset_id = max(offset_id, min_id) - if offset_id and max_id: - if max_id - offset_id <= 1: - raise StopAsyncIteration - - if not max_id: - max_id = float('inf') - else: - offset_id = max(offset_id, max_id) - if offset_id and min_id: - if offset_id - min_id <= 1: - raise StopAsyncIteration - - if self.reverse: - if offset_id: - offset_id += 1 - else: - offset_id = 1 - if from_user: - from_user = await self.client.get_input_entity(from_user) - if not isinstance(from_user, ( - types.InputPeerUser, types.InputPeerSelf)): - from_user = None # Ignore from_user unless it's a user - - self.from_id = (await self.client.get_peer_id(from_user)) if from_user else None - - if filter is None: - filter = types.InputMessagesFilterEmpty() - - # Telegram completely ignores `from_id` in private chats - if isinstance(entity, (types.InputPeerUser, types.InputPeerSelf)): - # Don't bother sending `from_user` (it's ignored anyway), - # but keep `from_id` defined above to check it locally. - from_user = None + self.from_id = await self.client.get_peer_id(from_user) else: - # Do send `from_user` to do the filtering server-side, - # and set `from_id` to None to avoid checking it locally. self.from_id = None - self.request = functions.messages.SearchRequest( - peer=entity, - q=search or '', - filter=filter() if isinstance(filter, type) else filter, - min_date=None, - max_date=offset_date, - offset_id=offset_id, - add_offset=add_offset, - limit=0, # Search actually returns 0 items if we ask it to - max_id=0, - min_id=0, - hash=0, - from_id=from_user - ) + if not self.entity: + self.request = functions.messages.SearchGlobalRequest( + q=search or '', + offset_date=offset_date, + offset_peer=types.InputPeerEmpty(), + offset_id=offset_id, + limit=1 + ) + elif search is not None or filter or from_user: + if filter is None: + filter = types.InputMessagesFilterEmpty() + + # Telegram completely ignores `from_id` in private chats + if isinstance( + self.entity, (types.InputPeerUser, types.InputPeerSelf)): + # Don't bother sending `from_user` (it's ignored anyway), + # but keep `from_id` defined above to check it locally. + from_user = None + else: + # Do send `from_user` to do the filtering server-side, + # and set `from_id` to None to avoid checking it locally. + self.from_id = None + + self.request = functions.messages.SearchRequest( + peer=self.entity, + q=search or '', + filter=filter() if isinstance(filter, type) else filter, + min_date=None, + max_date=offset_date, + offset_id=offset_id, + add_offset=add_offset, + limit=0, # Search actually returns 0 items if we ask it to + max_id=0, + min_id=0, + hash=0, + from_id=from_user + ) + else: + self.request = functions.messages.GetHistoryRequest( + peer=self.entity, + limit=1, + offset_date=offset_date, + offset_id=offset_id, + min_id=0, + max_id=0, + add_offset=add_offset, + hash=0 + ) if self.limit == 0: # No messages, but we still need to know the total message count @@ -227,19 +124,20 @@ class _SearchMessagesIter(RequestIter): # When going in reverse we need an offset of `-limit`, but we # also want to respect what the user passed, so add them together. if self.reverse: - self.request.add_offset -= batch_size + self.request.add_offset -= self.batch_size if self.wait_time is None: self.wait_time = 1 if self.limit > 3000 else 0 - # Telegram has a hard limit of 100. - # We don't need to fetch 100 if the limit is less. - self.batch_size = min(max(batch_size, 1), min(100, self.limit)) self.add_offset = add_offset self.max_id = max_id self.min_id = min_id self.last_id = 0 if self.reverse else float('inf') + # Telegram has a hard limit of 100. + # We don't need to fetch 100 if the limit is less. + self.batch_size = min(max(batch_size, 1), min(100, self.limit)) + async def _load_next_chunk(self): result = [] @@ -260,14 +158,9 @@ class _SearchMessagesIter(RequestIter): or self.from_id and message.from_id != self.from_id): continue - # TODO We used to yield and return here (stopping the iterator) - # How should we go around that here? - if self.reverse: - if message.id <= self.last_id or message.id >= self.max_id: - break - else: - if message.id >= self.last_id or message.id <= self.min_id: - break + if not self._message_in_range(message): + self.left = len(result) + break # There has been reports that on bad connections this method # was returning duplicated IDs sometimes. Using ``last_id`` @@ -278,34 +171,74 @@ class _SearchMessagesIter(RequestIter): result.append(message) if len(r.messages) < self.request.limit: - return result + self.left = len(result) - # Find the first message that's not empty (in some rare cases + # Get the first message that's not empty (in some rare cases # it can happen that the last message is :tl:`MessageEmpty`) - last_message = None - messages = r.messages if self.reverse else reversed(r.messages) - for m in messages: - if not isinstance(m, types.MessageEmpty): - last_message = m - break - - # TODO If it's None, we used to break (ending the iterator) - # Similar case as the return above. - if last_message is not None: + if result: + self._update_offset(result[0]) + else: # There are some cases where all the messages we get start # being empty. This can happen on migrated mega-groups if # the history was cleared, and we're using search. Telegram # acts incredibly weird sometimes. Messages are returned but # only "empty", not their contents. If this is the case we # should just give up since there won't be any new Message. - self.request.offset_id = last_message.id - self.request.max_date = last_message.date # not offset_date - if self.reverse: - # We want to skip the one we already have - self.request.offset_id += 1 + self.left = len(result) return result + def _message_in_range(self, message): + """ + Determine whether the given message is in the range or + it should be ignored (and avoid loading more chunks). + """ + # No entity means message IDs between chats may vary + if self.entity: + if self.reverse: + if message.id <= self.last_id or message.id >= self.max_id: + return False + else: + if message.id >= self.last_id or message.id <= self.min_id: + return False + + return True + + def _update_offset(self, last_message): + """ + After making the request, update its offset with the last message. + """ + self.request.offset_id = last_message.id + if self.reverse: + # We want to skip the one we already have + self.request.offset_id += 1 + + if isinstance(self.request, functions.messages.SearchRequest): + self.request.max_date = last_message.date + else: + # getHistory and searchGlobal call it offset_date + self.request.offset_date = last_message.date + + if isinstance(self.request, functions.messages.SearchGlobalRequest): + self.request.offset_peer = last_message.input_chat + + +class _IDsIter(RequestIter): + async def _init(self, entity, from_user, ids): + if not utils.is_list_like(ids): + self.ids = [ids] + elif not ids: + raise StopAsyncIteration + elif self.reverse: + self.ids = list(reversed(ids)) + else: + self.ids = ids + + raise NotImplementedError + + async def _load_next_chunk(self): + raise NotImplementedError + class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): @@ -428,242 +361,26 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): an higher limit, so you're free to set the ``batch_size`` that you think may be good. """ - # TODO Handle global search - # TODO Handle yield IDs - # TODO Reuse code between search, global, get history - if search is not None or filter or from_user: - return _SearchMessagesIter( - self, - limit, - entity=entity, - offset_id=offset_id, - min_id=min_id, - max_id=max_id, - from_user=from_user, - batch_size=batch_size, - offset_date=offset_date, - add_offset=add_offset, - filter=filter, - search=search - ) - else: - return _GetHistoryIter( - self, - limit, - wait_time=wait_time, - entity=entity, - reverse=reverse, - offset_id=offset_id, - min_id=min_id, - max_id=max_id, - from_user=from_user, - batch_size=batch_size, - offset_date=offset_date, - add_offset=add_offset - ) - # Note that entity being ``None`` is intended to get messages by - # ID under no specific chat, and also to request a global search. - if entity: - entity = await self.get_input_entity(entity) + if ids is not None: + return _IDsIter(self, limit, entity=entity, ids=ids) - if ids: - if not utils.is_list_like(ids): - ids = (ids,) - if reverse: - ids = list(reversed(ids)) - async for x in self._iter_ids(entity, ids, total=_total): - await yield_(x) - return - - # Telegram doesn't like min_id/max_id. If these IDs are low enough - # (starting from last_id - 100), the request will return nothing. - # - # We can emulate their behaviour locally by setting offset = max_id - # and simply stopping once we hit a message with ID <= min_id. - if reverse: - offset_id = max(offset_id, min_id) - if offset_id and max_id: - if max_id - offset_id <= 1: - return - - if not max_id: - max_id = float('inf') - else: - offset_id = max(offset_id, max_id) - if offset_id and min_id: - if offset_id - min_id <= 1: - return - - if reverse: - if offset_id: - offset_id += 1 - else: - offset_id = 1 - - if from_user: - from_user = await self.get_input_entity(from_user) - if not isinstance(from_user, ( - types.InputPeerUser, types.InputPeerSelf)): - from_user = None # Ignore from_user unless it's a user - - from_id = (await self.get_peer_id(from_user)) if from_user else None - - limit = float('inf') if limit is None else int(limit) - if not entity: - if reverse: - raise ValueError('Cannot reverse global search') - - reverse = None - request = functions.messages.SearchGlobalRequest( - q=search or '', - offset_date=offset_date, - offset_peer=types.InputPeerEmpty(), - offset_id=offset_id, - limit=1 - ) - elif search is not None or filter or from_user: - if filter is None: - filter = types.InputMessagesFilterEmpty() - - # Telegram completely ignores `from_id` in private chats - if isinstance(entity, (types.InputPeerUser, types.InputPeerSelf)): - # Don't bother sending `from_user` (it's ignored anyway), - # but keep `from_id` defined above to check it locally. - from_user = None - else: - # Do send `from_user` to do the filtering server-side, - # and set `from_id` to None to avoid checking it locally. - from_id = None - - request = functions.messages.SearchRequest( - peer=entity, - q=search or '', - filter=filter() if isinstance(filter, type) else filter, - min_date=None, - max_date=offset_date, - offset_id=offset_id, - add_offset=add_offset, - limit=0, # Search actually returns 0 items if we ask it to - max_id=0, - min_id=0, - hash=0, - from_id=from_user - ) - else: - request = functions.messages.GetHistoryRequest( - peer=entity, - limit=1, - offset_date=offset_date, - offset_id=offset_id, - min_id=0, - max_id=0, - add_offset=add_offset, - hash=0 - ) - - if limit == 0: - if not _total: - return - # No messages, but we still need to know the total message count - result = await self(request) - if isinstance(result, types.messages.MessagesNotModified): - _total[0] = result.count - else: - _total[0] = getattr(result, 'count', len(result.messages)) - return - - if wait_time is None: - wait_time = 1 if limit > 3000 else 0 - - have = 0 - last_id = 0 if reverse else float('inf') - - # Telegram has a hard limit of 100. - # We don't need to fetch 100 if the limit is less. - batch_size = min(max(batch_size, 1), min(100, limit)) - - # When going in reverse we need an offset of `-limit`, but we - # also want to respect what the user passed, so add them together. - if reverse: - request.add_offset -= batch_size - - while have < limit: - start = time.time() - - request.limit = min(limit - have, batch_size) - if reverse and request.limit != batch_size: - # Remember that we need -limit when going in reverse - request.add_offset = add_offset - request.limit - - r = await self(request) - if _total: - _total[0] = getattr(r, 'count', len(r.messages)) - - entities = {utils.get_peer_id(x): x - for x in itertools.chain(r.users, r.chats)} - - messages = reversed(r.messages) if reverse else r.messages - for message in messages: - if (isinstance(message, types.MessageEmpty) - or from_id and message.from_id != from_id): - continue - - if reverse is None: - pass - elif reverse: - if message.id <= last_id or message.id >= max_id: - return - else: - if message.id >= last_id or message.id <= min_id: - return - - # There has been reports that on bad connections this method - # was returning duplicated IDs sometimes. Using ``last_id`` - # is an attempt to avoid these duplicates, since the message - # IDs are returned in descending order (or asc if reverse). - last_id = message.id - - message._finish_init(self, entities, entity) - await yield_(message) - have += 1 - - if len(r.messages) < request.limit: - break - - # Find the first message that's not empty (in some rare cases - # it can happen that the last message is :tl:`MessageEmpty`) - last_message = None - messages = r.messages if reverse else reversed(r.messages) - for m in messages: - if not isinstance(m, types.MessageEmpty): - last_message = m - break - - if last_message is None: - # There are some cases where all the messages we get start - # being empty. This can happen on migrated mega-groups if - # the history was cleared, and we're using search. Telegram - # acts incredibly weird sometimes. Messages are returned but - # only "empty", not their contents. If this is the case we - # should just give up since there won't be any new Message. - break - else: - request.offset_id = last_message.id - if isinstance(request, functions.messages.SearchRequest): - request.max_date = last_message.date - else: - # getHistory and searchGlobal call it offset_date - request.offset_date = last_message.date - - if isinstance(request, functions.messages.SearchGlobalRequest): - request.offset_peer = last_message.input_chat - elif reverse: - # We want to skip the one we already have - request.offset_id += 1 - - await asyncio.sleep( - max(wait_time - (time.time() - start), 0), loop=self._loop) + return _MessagesIter( + client=self, + reverse=reverse, + wait_time=wait_time, + limit=limit, + entity=entity, + offset_id=offset_id, + min_id=min_id, + max_id=max_id, + from_user=from_user, + batch_size=batch_size, + offset_date=offset_date, + add_offset=add_offset, + filter=filter, + search=search + ) async def get_messages(self, *args, **kwargs): """ @@ -682,6 +399,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): a single `Message ` will be returned for convenience instead of a list. """ + # TODO Make RequestIter have a .collect() or similar total = [0] kwargs['_total'] = total if len(args) == 1 and 'limit' not in kwargs: diff --git a/telethon/requestiter.py b/telethon/requestiter.py index d8ccff53..ba897a2a 100644 --- a/telethon/requestiter.py +++ b/telethon/requestiter.py @@ -3,6 +3,10 @@ import asyncio import time +# TODO There are two types of iterators for requests. +# One has a limit of items to retrieve, and the +# other has a list that must be called in chunks. +# Make classes for both here so it's easy to use. class RequestIter(abc.ABC): """ Helper class to deal with requests that need offsets to iterate. @@ -50,6 +54,9 @@ class RequestIter(abc.ABC): if self.buffer is (): await self._init(**self.kwargs) + if self.left <= 0: # <= 0 because subclasses may change it + raise StopAsyncIteration + if self.index == len(self.buffer): # asyncio will handle times <= 0 to sleep 0 seconds if self.wait_time: @@ -84,7 +91,7 @@ class RequestIter(abc.ABC): 'is running (i.e. you are inside an "async def")' ) - raise NotImplementedError('lol!') + return self.__aiter__() @abc.abstractmethod async def _load_next_chunk(self): From f765f73fa372b3c176d32c4411cd7534be3df48f Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 09:32:33 +0100 Subject: [PATCH 04/15] Fix setting batch size --- telethon/client/messages.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/telethon/client/messages.py b/telethon/client/messages.py index e8d27961..2ace90ed 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -121,23 +121,23 @@ class _MessagesIter(RequestIter): self.total = getattr(result, 'count', len(result.messages)) raise StopAsyncIteration + if self.wait_time is None: + self.wait_time = 1 if self.limit > 3000 else 0 + + # Telegram has a hard limit of 100. + # We don't need to fetch 100 if the limit is less. + self.batch_size = min(max(batch_size, 1), min(100, self.limit)) + # When going in reverse we need an offset of `-limit`, but we # also want to respect what the user passed, so add them together. if self.reverse: self.request.add_offset -= self.batch_size - if self.wait_time is None: - self.wait_time = 1 if self.limit > 3000 else 0 - self.add_offset = add_offset self.max_id = max_id self.min_id = min_id self.last_id = 0 if self.reverse else float('inf') - # Telegram has a hard limit of 100. - # We don't need to fetch 100 if the limit is less. - self.batch_size = min(max(batch_size, 1), min(100, self.limit)) - async def _load_next_chunk(self): result = [] From e3991fadd583aea31b5da8b5ef8210da87b52a91 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 09:37:12 +0100 Subject: [PATCH 05/15] Fix updating offset --- telethon/client/messages.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 2ace90ed..7dc81ae7 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -173,10 +173,10 @@ class _MessagesIter(RequestIter): if len(r.messages) < self.request.limit: self.left = len(result) - # Get the first message that's not empty (in some rare cases + # Get the last message that's not empty (in some rare cases # it can happen that the last message is :tl:`MessageEmpty`) if result: - self._update_offset(result[0]) + self._update_offset(result[-1]) else: # There are some cases where all the messages we get start # being empty. This can happen on migrated mega-groups if From 35dc46ffb0fb3202fe21315b380339943aad6974 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 09:48:47 +0100 Subject: [PATCH 06/15] Fix searching messages in reverse --- telethon/client/messages.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 7dc81ae7..f14d7398 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -214,7 +214,11 @@ class _MessagesIter(RequestIter): self.request.offset_id += 1 if isinstance(self.request, functions.messages.SearchRequest): - self.request.max_date = last_message.date + # Unlike getHistory and searchGlobal that use *offset* date, + # this is *max* date. This means that doing a search in reverse + # will break it. Since it's not really needed once we're going + # (only for the first request), it's safe to just clear it off. + self.request.max_date = None else: # getHistory and searchGlobal call it offset_date self.request.offset_date = last_message.date From 60606b999494bc189dc42240cdabc968d5d5f1f8 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 09:49:14 +0100 Subject: [PATCH 07/15] Don't make iter_messages a coroutine function --- telethon/client/messages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/telethon/client/messages.py b/telethon/client/messages.py index f14d7398..1181b14f 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -250,7 +250,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): # region Message retrieval - async def iter_messages( + def iter_messages( self, entity, limit=None, *, offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0, search=None, filter=None, from_user=None, batch_size=100, wait_time=None, ids=None, From 6d6c1917bcf28d7ead91ddf34e179c27bd709ccb Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 10:04:12 +0100 Subject: [PATCH 08/15] Implement iterator over message by IDs --- telethon/client/messages.py | 97 ++++++++++++++++++------------------- telethon/requestiter.py | 2 + 2 files changed, 48 insertions(+), 51 deletions(-) diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 1181b14f..9750003d 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -228,7 +228,8 @@ class _MessagesIter(RequestIter): class _IDsIter(RequestIter): - async def _init(self, entity, from_user, ids): + async def _init(self, entity, ids): + # TODO We never actually split IDs in chunks, but maybe we should if not utils.is_list_like(ids): self.ids = [ids] elif not ids: @@ -238,10 +239,52 @@ class _IDsIter(RequestIter): else: self.ids = ids - raise NotImplementedError + if entity: + entity = await self.client.get_input_entity(entity) + + self.total = len(ids) + + from_id = None # By default, no need to validate from_id + if isinstance(entity, (types.InputChannel, types.InputPeerChannel)): + try: + r = await self.client( + functions.channels.GetMessagesRequest(entity, ids)) + except errors.MessageIdsEmptyError: + # All IDs were invalid, use a dummy result + r = types.messages.MessagesNotModified(len(ids)) + else: + r = await self.client(functions.messages.GetMessagesRequest(ids)) + if entity: + from_id = await self.client.get_peer_id(entity) + + if isinstance(r, types.messages.MessagesNotModified): + self.buffer = [None] * len(ids) + return + + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + # Telegram seems to return the messages in the order in which + # we asked them for, so we don't need to check it ourselves, + # unless some messages were invalid in which case Telegram + # may decide to not send them at all. + # + # The passed message IDs may not belong to the desired entity + # since the user can enter arbitrary numbers which can belong to + # arbitrary chats. Validate these unless ``from_id is None``. + result = [] + for message in r.messages: + if isinstance(message, types.MessageEmpty) or ( + from_id and message.chat_id != from_id): + result.append(None) + else: + message._finish_init(self.client, entities, entity) + result.append(message) + + self.buffer = result async def _load_next_chunk(self): - raise NotImplementedError + return [] # no next chunk, all done in init class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): @@ -850,51 +893,3 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): # endregion # endregion - - # region Private methods - - async def _iter_ids(self, entity, ids, total): - """ - Special case for `iter_messages` when it should only fetch some IDs. - """ - if total: - total[0] = len(ids) - - from_id = None # By default, no need to validate from_id - if isinstance(entity, (types.InputChannel, types.InputPeerChannel)): - try: - r = await self( - functions.channels.GetMessagesRequest(entity, ids)) - except errors.MessageIdsEmptyError: - # All IDs were invalid, use a dummy result - r = types.messages.MessagesNotModified(len(ids)) - else: - r = await self(functions.messages.GetMessagesRequest(ids)) - if entity: - from_id = utils.get_peer_id(entity) - - if isinstance(r, types.messages.MessagesNotModified): - for _ in ids: - await yield_(None) - return - - entities = {utils.get_peer_id(x): x - for x in itertools.chain(r.users, r.chats)} - - # Telegram seems to return the messages in the order in which - # we asked them for, so we don't need to check it ourselves, - # unless some messages were invalid in which case Telegram - # may decide to not send them at all. - # - # The passed message IDs may not belong to the desired entity - # since the user can enter arbitrary numbers which can belong to - # arbitrary chats. Validate these unless ``from_id is None``. - for message in r.messages: - if isinstance(message, types.MessageEmpty) or ( - from_id and message.chat_id != from_id): - await yield_(None) - else: - message._finish_init(self, entities, entity) - await yield_(message) - - # endregion diff --git a/telethon/requestiter.py b/telethon/requestiter.py index ba897a2a..98a1cfb6 100644 --- a/telethon/requestiter.py +++ b/telethon/requestiter.py @@ -48,6 +48,8 @@ class RequestIter(abc.ABC): to avoid forgetting or misspelling any of them. This method may ``raise StopAsyncIteration`` if it cannot continue. + + This method may actually fill the initial buffer if it needs to. """ async def __anext__(self): From 5b8e6531fa0b8f1f95c318dc5008b52b7f9334c9 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 10:15:32 +0100 Subject: [PATCH 09/15] Add method to collect RequestIter into TotalList --- telethon/client/messages.py | 21 ++++++++++----------- telethon/requestiter.py | 13 +++++++++++++ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 9750003d..87d49be6 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -446,24 +446,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): a single `Message ` will be returned for convenience instead of a list. """ - # TODO Make RequestIter have a .collect() or similar - total = [0] - kwargs['_total'] = total if len(args) == 1 and 'limit' not in kwargs: if 'min_id' in kwargs and 'max_id' in kwargs: kwargs['limit'] = None else: kwargs['limit'] = 1 - msgs = helpers.TotalList() - async for x in self.iter_messages(*args, **kwargs): - msgs.append(x) - msgs.total = total[0] - if 'ids' in kwargs and not utils.is_list_like(kwargs['ids']): - # Check for empty list to handle InputMessageReplyTo - return msgs[0] if msgs else None + it = self.iter_messages(*args, **kwargs) - return msgs + ids = kwargs.get('ids') + if ids and not utils.is_list_like(ids): + async for message in it: + return message + else: + # Iterator exhausted = empty, to handle InputMessageReplyTo + return None + + return await it.collect() # endregion diff --git a/telethon/requestiter.py b/telethon/requestiter.py index 98a1cfb6..af632389 100644 --- a/telethon/requestiter.py +++ b/telethon/requestiter.py @@ -2,6 +2,8 @@ import abc import asyncio import time +from . import helpers + # TODO There are two types of iterators for requests. # One has a limit of items to retrieve, and the @@ -95,6 +97,17 @@ class RequestIter(abc.ABC): return self.__aiter__() + async def collect(self): + """ + Create a `self` iterator and collect it into a `TotalList` + (a normal list with a `.total` attribute). + """ + result = helpers.TotalList() + async for message in self: + result.append(message) + + return result + @abc.abstractmethod async def _load_next_chunk(self): """ From 49d8a3fb331580e7c86a740d233ef662a4b77104 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 10:19:08 +0100 Subject: [PATCH 10/15] Remove code to syncify async generator functions --- telethon/sync.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/telethon/sync.py b/telethon/sync.py index 925b2046..c5f0f623 100644 --- a/telethon/sync.py +++ b/telethon/sync.py @@ -14,8 +14,6 @@ import asyncio import functools import inspect -from async_generator import isasyncgenfunction - from .client.telegramclient import TelegramClient from .tl.custom import ( Draft, Dialog, MessageButton, Forward, Message, InlineResult, Conversation @@ -24,22 +22,7 @@ from .tl.custom.chatgetter import ChatGetter from .tl.custom.sendergetter import SenderGetter -class _SyncGen: - def __init__(self, gen): - self.gen = gen - - def __iter__(self): - return self - - def __next__(self): - try: - return asyncio.get_event_loop() \ - .run_until_complete(self.gen.__anext__()) - except StopAsyncIteration: - raise StopIteration from None - - -def _syncify_wrap(t, method_name, gen): +def _syncify_wrap(t, method_name): method = getattr(t, method_name) @functools.wraps(method) @@ -48,8 +31,6 @@ def _syncify_wrap(t, method_name, gen): loop = asyncio.get_event_loop() if loop.is_running(): return coro - elif gen: - return _SyncGen(coro) else: return loop.run_until_complete(coro) @@ -64,13 +45,14 @@ def syncify(*types): into synchronous, which return either the coroutine or the result based on whether ``asyncio's`` event loop is running. """ + # Our asynchronous generators all are `RequestIter`, which already + # provide a synchronous iterator variant, so we don't need to worry + # about asyncgenfunction's here. for t in types: for name in dir(t): if not name.startswith('_') or name == '__call__': if inspect.iscoroutinefunction(getattr(t, name)): - _syncify_wrap(t, name, gen=False) - elif isasyncgenfunction(getattr(t, name)): - _syncify_wrap(t, name, gen=True) + _syncify_wrap(t, name) syncify(TelegramClient, Draft, Dialog, MessageButton, From 968da5f72dfd3e9b114733b80ebe3275f918da15 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 10:37:40 +0100 Subject: [PATCH 11/15] Use RequestIter in the dialog methods --- telethon/client/dialogs.py | 195 +++++++++++++++++++------------------ 1 file changed, 101 insertions(+), 94 deletions(-) diff --git a/telethon/client/dialogs.py b/telethon/client/dialogs.py index 6c7edd24..bf3f903f 100644 --- a/telethon/client/dialogs.py +++ b/telethon/client/dialogs.py @@ -1,18 +1,105 @@ import itertools -from async_generator import async_generator, yield_ - from .users import UserMethods -from .. import utils, helpers +from .. import utils +from ..requestiter import RequestIter from ..tl import types, functions, custom +class _DialogsIter(RequestIter): + async def _init( + self, offset_date, offset_id, offset_peer, ignore_migrated + ): + self.request = functions.messages.GetDialogsRequest( + offset_date=offset_date, + offset_id=offset_id, + offset_peer=offset_peer, + limit=1, + hash=0 + ) + + if self.limit == 0: + # Special case, get a single dialog and determine count + dialogs = await self.client(self.request) + self.total = getattr(dialogs, 'count', len(dialogs.dialogs)) + raise StopAsyncIteration + + self.seen = set() + self.offset_date = offset_date + self.ignore_migrated = ignore_migrated + + async def _load_next_chunk(self): + result = [] + + self.request.limit = min(self.left, 100) + r = await self.client(self.request) + + self.total = getattr(r, 'count', len(r.dialogs)) + + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + messages = {} + for m in r.messages: + m._finish_init(self, entities, None) + messages[m.id] = m + + for d in r.dialogs: + # We check the offset date here because Telegram may ignore it + if self.offset_date: + date = getattr(messages.get( + d.top_message, None), 'date', None) + + if not date or date.timestamp() > self.offset_date.timestamp(): + continue + + peer_id = utils.get_peer_id(d.peer) + if peer_id not in self.seen: + self.seen.add(peer_id) + cd = custom.Dialog(self, d, entities, messages) + if cd.dialog.pts: + self.client._channel_pts[cd.id] = cd.dialog.pts + + if not self.ignore_migrated or getattr( + cd.entity, 'migrated_to', None) is None: + result.append(cd) + + if len(r.dialogs) < self.request.limit\ + or not isinstance(r, types.messages.DialogsSlice): + # Less than we requested means we reached the end, or + # we didn't get a DialogsSlice which means we got all. + self.left = len(result) + + self.request.offset_date = r.messages[-1].date + self.request.offset_peer =\ + entities[utils.get_peer_id(r.dialogs[-1].peer)] + + if self.request.offset_id == r.messages[-1].id: + # In some very rare cases this will get stuck in an infinite + # loop, where the offsets will get reused over and over. If + # the new offset is the same as the one before, break already. + self.left = len(result) + + self.request.offset_id = r.messages[-1].id + self.request.exclude_pinned = True + return result + + +class _DraftsIter(RequestIter): + async def _init(self, **kwargs): + r = await self.client(functions.messages.GetAllDraftsRequest()) + self.buffer = [custom.Draft._from_update(self.client, u) + for u in r.updates] + + async def _load_next_chunk(self): + return [] + + class DialogMethods(UserMethods): # region Public methods - @async_generator - async def iter_dialogs( + def iter_dialogs( self, limit=None, *, offset_date=None, offset_id=0, offset_peer=types.InputPeerEmpty(), ignore_migrated=False, _total=None): @@ -50,99 +137,23 @@ class DialogMethods(UserMethods): Yields: Instances of `telethon.tl.custom.dialog.Dialog`. """ - limit = float('inf') if limit is None else int(limit) - if limit == 0: - if not _total: - return - # Special case, get a single dialog and determine count - dialogs = await self(functions.messages.GetDialogsRequest( - offset_date=offset_date, - offset_id=offset_id, - offset_peer=offset_peer, - limit=1, - hash=0 - )) - _total[0] = getattr(dialogs, 'count', len(dialogs.dialogs)) - return - - seen = set() - req = functions.messages.GetDialogsRequest( + return _DialogsIter( + self, + limit, offset_date=offset_date, offset_id=offset_id, offset_peer=offset_peer, - limit=0, - hash=0 + ignore_migrated=ignore_migrated ) - while len(seen) < limit: - req.limit = min(limit - len(seen), 100) - r = await self(req) - - if _total: - _total[0] = getattr(r, 'count', len(r.dialogs)) - - entities = {utils.get_peer_id(x): x - for x in itertools.chain(r.users, r.chats)} - - messages = {} - for m in r.messages: - m._finish_init(self, entities, None) - messages[m.id] = m - - # Happens when there are pinned dialogs - if len(r.dialogs) > limit: - r.dialogs = r.dialogs[:limit] - - for d in r.dialogs: - if offset_date: - date = getattr(messages.get( - d.top_message, None), 'date', None) - - if not date or date.timestamp() > offset_date.timestamp(): - continue - - peer_id = utils.get_peer_id(d.peer) - if peer_id not in seen: - seen.add(peer_id) - cd = custom.Dialog(self, d, entities, messages) - if cd.dialog.pts: - self._channel_pts[cd.id] = cd.dialog.pts - - if not ignore_migrated or getattr( - cd.entity, 'migrated_to', None) is None: - await yield_(cd) - - if len(r.dialogs) < req.limit\ - or not isinstance(r, types.messages.DialogsSlice): - # Less than we requested means we reached the end, or - # we didn't get a DialogsSlice which means we got all. - break - - req.offset_date = r.messages[-1].date - req.offset_peer = entities[utils.get_peer_id(r.dialogs[-1].peer)] - if req.offset_id == r.messages[-1].id: - # In some very rare cases this will get stuck in an infinite - # loop, where the offsets will get reused over and over. If - # the new offset is the same as the one before, break already. - break - - req.offset_id = r.messages[-1].id - req.exclude_pinned = True async def get_dialogs(self, *args, **kwargs): """ Same as `iter_dialogs`, but returns a `TotalList ` instead. """ - total = [0] - kwargs['_total'] = total - dialogs = helpers.TotalList() - async for x in self.iter_dialogs(*args, **kwargs): - dialogs.append(x) - dialogs.total = total[0] - return dialogs + return await self.iter_dialogs(*args, **kwargs).collect() - @async_generator - async def iter_drafts(self): + def iter_drafts(self): """ Iterator over all open draft messages. @@ -151,18 +162,14 @@ class DialogMethods(UserMethods): to change the message or `telethon.tl.custom.draft.Draft.delete` among other things. """ - r = await self(functions.messages.GetAllDraftsRequest()) - for update in r.updates: - await yield_(custom.Draft._from_update(self, update)) + # TODO Passing a limit here makes no sense + return _DraftsIter(self, None) async def get_drafts(self): """ Same as :meth:`iter_drafts`, but returns a list instead. """ - result = [] - async for x in self.iter_drafts(): - result.append(x) - return result + return await self.iter_drafts().collect() def conversation( self, entity, From 4f647847e70b31b7020b7c7c475834d7bc6b5479 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 10:39:56 +0100 Subject: [PATCH 12/15] Fix RequestIter not setting TotalList.total in collect() --- telethon/requestiter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/telethon/requestiter.py b/telethon/requestiter.py index af632389..ca9c8d72 100644 --- a/telethon/requestiter.py +++ b/telethon/requestiter.py @@ -106,6 +106,7 @@ class RequestIter(abc.ABC): async for message in self: result.append(message) + result.total = self.total return result @abc.abstractmethod From 40ded93c7c4bfe9aa377a6f3b4ca90e64bf11d28 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 11:12:05 +0100 Subject: [PATCH 13/15] Use RequestIter in chat methods --- telethon/client/chats.py | 408 +++++++++++++++++++++------------------ 1 file changed, 222 insertions(+), 186 deletions(-) diff --git a/telethon/client/chats.py b/telethon/client/chats.py index 6578c785..7afe4655 100644 --- a/telethon/client/chats.py +++ b/telethon/client/chats.py @@ -1,19 +1,202 @@ import itertools -import sys - -from async_generator import async_generator, yield_ from .users import UserMethods -from .. import utils, helpers +from .. import utils +from ..requestiter import RequestIter from ..tl import types, functions, custom +class _ParticipantsIter(RequestIter): + async def _init(self, entity, filter, search, aggressive): + if isinstance(filter, type): + if filter in (types.ChannelParticipantsBanned, + types.ChannelParticipantsKicked, + types.ChannelParticipantsSearch, + types.ChannelParticipantsContacts): + # These require a `q` parameter (support types for convenience) + filter = filter('') + else: + filter = filter() + + entity = await self.client.get_input_entity(entity) + if search and (filter + or not isinstance(entity, types.InputPeerChannel)): + # We need to 'search' ourselves unless we have a PeerChannel + search = search.lower() + + self.filter_entity = lambda ent: ( + search in utils.get_display_name(ent).lower() or + search in (getattr(ent, 'username', '') or None).lower() + ) + else: + self.filter_entity = lambda ent: True + + if isinstance(entity, types.InputPeerChannel): + self.total = (await self.client( + functions.channels.GetFullChannelRequest(entity) + )).full_chat.participants_count + + if self.limit == 0: + raise StopAsyncIteration + + self.seen = set() + if aggressive and not filter: + self.requests = [functions.channels.GetParticipantsRequest( + channel=entity, + filter=types.ChannelParticipantsSearch(x), + offset=0, + limit=200, + hash=0 + ) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))] + else: + self.requests = [functions.channels.GetParticipantsRequest( + channel=entity, + filter=filter or types.ChannelParticipantsSearch(search), + offset=0, + limit=200, + hash=0 + )] + + elif isinstance(entity, types.InputPeerChat): + full = await self.client( + functions.messages.GetFullChatRequest(entity.chat_id)) + if not isinstance( + full.full_chat.participants, types.ChatParticipants): + # ChatParticipantsForbidden won't have ``.participants`` + self.total = 0 + raise StopAsyncIteration + + self.total = len(full.full_chat.participants.participants) + + result = [] + users = {user.id: user for user in full.users} + for participant in full.full_chat.participants.participants: + user = users[participant.user_id] + if not self.filter_entity(user): + continue + + user = users[participant.user_id] + user.participant = participant + result.append(user) + + self.left = len(result) + self.buffer = result + else: + result = [] + self.total = 1 + if self.limit != 0: + user = await self.client.get_entity(entity) + if self.filter_entity(user): + user.participant = None + result.append(user) + + self.left = len(result) + self.buffer = result + + async def _load_next_chunk(self): + result = [] + if not self.requests: + return result + + # Only care about the limit for the first request + # (small amount of people, won't be aggressive). + # + # Most people won't care about getting exactly 12,345 + # members so it doesn't really matter not to be 100% + # precise with being out of the offset/limit here. + self.requests[0].limit = min(self.limit - self.requests[0].offset, 200) + if self.requests[0].offset > self.limit: + return result + + results = await self.client(self.requests) + for i in reversed(range(len(self.requests))): + participants = results[i] + if not participants.users: + self.requests.pop(i) + continue + + self.requests[i].offset += len(participants.participants) + users = {user.id: user for user in participants.users} + for participant in participants.participants: + user = users[participant.user_id] + if not self.filter_entity(user) or user.id in self.seen: + continue + + self.seen.add(participant.user_id) + user = users[participant.user_id] + user.participant = participant + result.append(user) + + return result + + +class _AdminLogIter(RequestIter): + async def _init( + self, entity, admins, search, min_id, max_id, + join, leave, invite, restrict, unrestrict, ban, unban, + promote, demote, info, settings, pinned, edit, delete + ): + if any((join, leave, invite, restrict, unrestrict, ban, unban, + promote, demote, info, settings, pinned, edit, delete)): + events_filter = types.ChannelAdminLogEventsFilter( + join=join, leave=leave, invite=invite, ban=restrict, + unban=unrestrict, kick=ban, unkick=unban, promote=promote, + demote=demote, info=info, settings=settings, pinned=pinned, + edit=edit, delete=delete + ) + else: + events_filter = None + + self.entity = await self.client.get_input_entity(entity) + + admin_list = [] + if admins: + if not utils.is_list_like(admins): + admins = (admins,) + + for admin in admins: + admin_list.append(await self.client.get_input_entity(admin)) + + self.request = functions.channels.GetAdminLogRequest( + self.entity, q=search or '', min_id=min_id, max_id=max_id, + limit=0, events_filter=events_filter, admins=admin_list or None + ) + + async def _load_next_chunk(self): + result = [] + self.request.limit = min(self.left, 100) + r = await self.client(self.request) + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + self.request.max_id = min((e.id for e in r.events), default=0) + for ev in r.events: + if isinstance(ev.action, + types.ChannelAdminLogEventActionEditMessage): + ev.action.prev_message._finish_init( + self.client, entities, self.entity) + + ev.action.new_message._finish_init( + self.client, entities, self.entity) + + elif isinstance(ev.action, + types.ChannelAdminLogEventActionDeleteMessage): + ev.action.message._finish_init( + self.client, entities, self.entity) + + result.append(custom.AdminLogEvent(ev, entities)) + + if len(r.events) < self.request.limit: + self.left = len(result) + + return result + + class ChatMethods(UserMethods): # region Public methods - @async_generator - async def iter_participants( + def iter_participants( self, entity, limit=None, *, search='', filter=None, aggressive=False, _total=None): """ @@ -62,138 +245,23 @@ class ChatMethods(UserMethods): matched :tl:`ChannelParticipant` type for channels/megagroups or :tl:`ChatParticipants` for normal chats. """ - if isinstance(filter, type): - if filter in (types.ChannelParticipantsBanned, - types.ChannelParticipantsKicked, - types.ChannelParticipantsSearch, - types.ChannelParticipantsContacts): - # These require a `q` parameter (support types for convenience) - filter = filter('') - else: - filter = filter() - - entity = await self.get_input_entity(entity) - if search and (filter - or not isinstance(entity, types.InputPeerChannel)): - # We need to 'search' ourselves unless we have a PeerChannel - search = search.lower() - - def filter_entity(ent): - return search in utils.get_display_name(ent).lower() or\ - search in (getattr(ent, 'username', '') or None).lower() - else: - def filter_entity(ent): - return True - - limit = float('inf') if limit is None else int(limit) - if isinstance(entity, types.InputPeerChannel): - if _total: - _total[0] = (await self( - functions.channels.GetFullChannelRequest(entity) - )).full_chat.participants_count - - if limit == 0: - return - - seen = set() - if aggressive and not filter: - requests = [functions.channels.GetParticipantsRequest( - channel=entity, - filter=types.ChannelParticipantsSearch(x), - offset=0, - limit=200, - hash=0 - ) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))] - else: - requests = [functions.channels.GetParticipantsRequest( - channel=entity, - filter=filter or types.ChannelParticipantsSearch(search), - offset=0, - limit=200, - hash=0 - )] - - while requests: - # Only care about the limit for the first request - # (small amount of people, won't be aggressive). - # - # Most people won't care about getting exactly 12,345 - # members so it doesn't really matter not to be 100% - # precise with being out of the offset/limit here. - requests[0].limit = min(limit - requests[0].offset, 200) - if requests[0].offset > limit: - break - - results = await self(requests) - for i in reversed(range(len(requests))): - participants = results[i] - if not participants.users: - requests.pop(i) - else: - requests[i].offset += len(participants.participants) - users = {user.id: user for user in participants.users} - for participant in participants.participants: - user = users[participant.user_id] - if not filter_entity(user) or user.id in seen: - continue - - seen.add(participant.user_id) - user = users[participant.user_id] - user.participant = participant - await yield_(user) - if len(seen) >= limit: - return - - elif isinstance(entity, types.InputPeerChat): - full = await self( - functions.messages.GetFullChatRequest(entity.chat_id)) - if not isinstance( - full.full_chat.participants, types.ChatParticipants): - # ChatParticipantsForbidden won't have ``.participants`` - if _total: - _total[0] = 0 - return - - if _total: - _total[0] = len(full.full_chat.participants.participants) - - have = 0 - users = {user.id: user for user in full.users} - for participant in full.full_chat.participants.participants: - user = users[participant.user_id] - if not filter_entity(user): - continue - have += 1 - if have > limit: - break - else: - user = users[participant.user_id] - user.participant = participant - await yield_(user) - else: - if _total: - _total[0] = 1 - if limit != 0: - user = await self.get_entity(entity) - if filter_entity(user): - user.participant = None - await yield_(user) + return _ParticipantsIter( + self, + limit, + entity=entity, + filter=filter, + search=search, + aggressive=aggressive + ) async def get_participants(self, *args, **kwargs): """ Same as `iter_participants`, but returns a `TotalList ` instead. """ - total = [0] - kwargs['_total'] = total - participants = helpers.TotalList() - async for x in self.iter_participants(*args, **kwargs): - participants.append(x) - participants.total = total[0] - return participants + return await self.iter_participants(*args, **kwargs).collect() - @async_generator - async def iter_admin_log( + def iter_admin_log( self, entity, limit=None, *, max_id=0, min_id=0, search=None, admins=None, join=None, leave=None, invite=None, restrict=None, unrestrict=None, ban=None, unban=None, promote=None, demote=None, @@ -285,66 +353,34 @@ class ChatMethods(UserMethods): Yields: Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`. """ - if limit is None: - limit = sys.maxsize - elif limit <= 0: - return - - if any((join, leave, invite, restrict, unrestrict, ban, unban, - promote, demote, info, settings, pinned, edit, delete)): - events_filter = types.ChannelAdminLogEventsFilter( - join=join, leave=leave, invite=invite, ban=restrict, - unban=unrestrict, kick=ban, unkick=unban, promote=promote, - demote=demote, info=info, settings=settings, pinned=pinned, - edit=edit, delete=delete - ) - else: - events_filter = None - - entity = await self.get_input_entity(entity) - - admin_list = [] - if admins: - if not utils.is_list_like(admins): - admins = (admins,) - - for admin in admins: - admin_list.append(await self.get_input_entity(admin)) - - request = functions.channels.GetAdminLogRequest( - entity, q=search or '', min_id=min_id, max_id=max_id, - limit=0, events_filter=events_filter, admins=admin_list or None + return _AdminLogIter( + self, + limit, + entity=entity, + admins=admins, + search=search, + min_id=min_id, + max_id=max_id, + join=join, + leave=leave, + invite=invite, + restrict=restrict, + unrestrict=unrestrict, + ban=ban, + unban=unban, + promote=promote, + demote=demote, + info=info, + settings=settings, + pinned=pinned, + edit=edit, + delete=delete ) - while limit > 0: - request.limit = min(limit, 100) - result = await self(request) - limit -= len(result.events) - entities = {utils.get_peer_id(x): x - for x in itertools.chain(result.users, result.chats)} - - request.max_id = min((e.id for e in result.events), default=0) - for ev in result.events: - if isinstance(ev.action, - types.ChannelAdminLogEventActionEditMessage): - ev.action.prev_message._finish_init(self, entities, entity) - ev.action.new_message._finish_init(self, entities, entity) - - elif isinstance(ev.action, - types.ChannelAdminLogEventActionDeleteMessage): - ev.action.message._finish_init(self, entities, entity) - - await yield_(custom.AdminLogEvent(ev, entities)) - - if len(result.events) < request.limit: - break async def get_admin_log(self, *args, **kwargs): """ Same as `iter_admin_log`, but returns a ``list`` instead. """ - admin_log = [] - async for x in self.iter_admin_log(*args, **kwargs): - admin_log.append(x) - return admin_log + return await self.iter_admin_log(*args, **kwargs).collect() # endregion From 202ce1f4946c44e27132c5fe355ef4ce45320262 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 11:13:29 +0100 Subject: [PATCH 14/15] Remove async_generator from dependencies --- requirements.txt | 1 - setup.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 43e88e96..2b650ec4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,2 @@ pyaes rsa -async_generator diff --git a/setup.py b/setup.py index cec63880..1456cbcb 100755 --- a/setup.py +++ b/setup.py @@ -214,8 +214,7 @@ def main(): packages=find_packages(exclude=[ 'telethon_*', 'run_tests.py', 'try_telethon.py' ]), - install_requires=['pyaes', 'rsa', - 'async_generator'], + install_requires=['pyaes', 'rsa'], extras_require={ 'cryptg': ['cryptg'] } From c73b8eda26badb5b2d5fcb59a082e37f4f8ce3be Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 11:24:47 +0100 Subject: [PATCH 15/15] Simplify filling RequestIter's buffer --- telethon/client/chats.py | 28 +++++++++------------------- telethon/client/dialogs.py | 20 ++++++++------------ telethon/client/messages.py | 28 ++++++++++------------------ telethon/requestiter.py | 16 ++++++++++++---- 4 files changed, 39 insertions(+), 53 deletions(-) diff --git a/telethon/client/chats.py b/telethon/client/chats.py index 7afe4655..83f8ce5e 100644 --- a/telethon/client/chats.py +++ b/telethon/client/chats.py @@ -68,7 +68,6 @@ class _ParticipantsIter(RequestIter): self.total = len(full.full_chat.participants.participants) - result = [] users = {user.id: user for user in full.users} for participant in full.full_chat.participants.participants: user = users[participant.user_id] @@ -77,26 +76,22 @@ class _ParticipantsIter(RequestIter): user = users[participant.user_id] user.participant = participant - result.append(user) + self.buffer.append(user) - self.left = len(result) - self.buffer = result + return True else: - result = [] self.total = 1 if self.limit != 0: user = await self.client.get_entity(entity) if self.filter_entity(user): user.participant = None - result.append(user) + self.buffer.append(user) - self.left = len(result) - self.buffer = result + return True async def _load_next_chunk(self): - result = [] if not self.requests: - return result + return True # Only care about the limit for the first request # (small amount of people, won't be aggressive). @@ -106,7 +101,7 @@ class _ParticipantsIter(RequestIter): # precise with being out of the offset/limit here. self.requests[0].limit = min(self.limit - self.requests[0].offset, 200) if self.requests[0].offset > self.limit: - return result + return True results = await self.client(self.requests) for i in reversed(range(len(self.requests))): @@ -125,9 +120,7 @@ class _ParticipantsIter(RequestIter): self.seen.add(participant.user_id) user = users[participant.user_id] user.participant = participant - result.append(user) - - return result + self.buffer.append(user) class _AdminLogIter(RequestIter): @@ -163,7 +156,6 @@ class _AdminLogIter(RequestIter): ) async def _load_next_chunk(self): - result = [] self.request.limit = min(self.left, 100) r = await self.client(self.request) entities = {utils.get_peer_id(x): x @@ -184,12 +176,10 @@ class _AdminLogIter(RequestIter): ev.action.message._finish_init( self.client, entities, self.entity) - result.append(custom.AdminLogEvent(ev, entities)) + self.buffer.append(custom.AdminLogEvent(ev, entities)) if len(r.events) < self.request.limit: - self.left = len(result) - - return result + return True class ChatMethods(UserMethods): diff --git a/telethon/client/dialogs.py b/telethon/client/dialogs.py index bf3f903f..453dfdec 100644 --- a/telethon/client/dialogs.py +++ b/telethon/client/dialogs.py @@ -29,8 +29,6 @@ class _DialogsIter(RequestIter): self.ignore_migrated = ignore_migrated async def _load_next_chunk(self): - result = [] - self.request.limit = min(self.left, 100) r = await self.client(self.request) @@ -62,34 +60,32 @@ class _DialogsIter(RequestIter): if not self.ignore_migrated or getattr( cd.entity, 'migrated_to', None) is None: - result.append(cd) + self.buffer.append(cd) if len(r.dialogs) < self.request.limit\ or not isinstance(r, types.messages.DialogsSlice): # Less than we requested means we reached the end, or # we didn't get a DialogsSlice which means we got all. - self.left = len(result) - - self.request.offset_date = r.messages[-1].date - self.request.offset_peer =\ - entities[utils.get_peer_id(r.dialogs[-1].peer)] + return True if self.request.offset_id == r.messages[-1].id: # In some very rare cases this will get stuck in an infinite # loop, where the offsets will get reused over and over. If # the new offset is the same as the one before, break already. - self.left = len(result) + return True self.request.offset_id = r.messages[-1].id self.request.exclude_pinned = True - return result + self.request.offset_date = r.messages[-1].date + self.request.offset_peer =\ + entities[utils.get_peer_id(r.dialogs[-1].peer)] class _DraftsIter(RequestIter): async def _init(self, **kwargs): r = await self.client(functions.messages.GetAllDraftsRequest()) - self.buffer = [custom.Draft._from_update(self.client, u) - for u in r.updates] + self.buffer.extend(custom.Draft._from_update(self.client, u) + for u in r.updates) async def _load_next_chunk(self): return [] diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 87d49be6..e87d80a0 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -139,8 +139,6 @@ class _MessagesIter(RequestIter): self.last_id = 0 if self.reverse else float('inf') async def _load_next_chunk(self): - result = [] - self.request.limit = min(self.left, self.batch_size) if self.reverse and self.request.limit != self.batch_size: # Remember that we need -limit when going in reverse @@ -159,8 +157,7 @@ class _MessagesIter(RequestIter): continue if not self._message_in_range(message): - self.left = len(result) - break + return True # There has been reports that on bad connections this method # was returning duplicated IDs sometimes. Using ``last_id`` @@ -168,15 +165,15 @@ class _MessagesIter(RequestIter): # IDs are returned in descending order (or asc if reverse). self.last_id = message.id message._finish_init(self.client, entities, self.entity) - result.append(message) + self.buffer.append(message) if len(r.messages) < self.request.limit: - self.left = len(result) + return True # Get the last message that's not empty (in some rare cases # it can happen that the last message is :tl:`MessageEmpty`) - if result: - self._update_offset(result[-1]) + if self.buffer: + self._update_offset(self.buffer[-1]) else: # There are some cases where all the messages we get start # being empty. This can happen on migrated mega-groups if @@ -184,9 +181,7 @@ class _MessagesIter(RequestIter): # acts incredibly weird sometimes. Messages are returned but # only "empty", not their contents. If this is the case we # should just give up since there won't be any new Message. - self.left = len(result) - - return result + return True def _message_in_range(self, message): """ @@ -258,7 +253,7 @@ class _IDsIter(RequestIter): from_id = await self.client.get_peer_id(entity) if isinstance(r, types.messages.MessagesNotModified): - self.buffer = [None] * len(ids) + self.buffer.extend(None for _ in ids) return entities = {utils.get_peer_id(x): x @@ -272,19 +267,16 @@ class _IDsIter(RequestIter): # The passed message IDs may not belong to the desired entity # since the user can enter arbitrary numbers which can belong to # arbitrary chats. Validate these unless ``from_id is None``. - result = [] for message in r.messages: if isinstance(message, types.MessageEmpty) or ( from_id and message.chat_id != from_id): - result.append(None) + self.buffer.append(None) else: message._finish_init(self.client, entities, entity) - result.append(message) - - self.buffer = result + self.buffer.append(message) async def _load_next_chunk(self): - return [] # no next chunk, all done in init + return True # no next chunk, all done in init class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): diff --git a/telethon/requestiter.py b/telethon/requestiter.py index ca9c8d72..111d4509 100644 --- a/telethon/requestiter.py +++ b/telethon/requestiter.py @@ -55,7 +55,8 @@ class RequestIter(abc.ABC): """ async def __anext__(self): - if self.buffer is (): + if self.buffer is None: + self.buffer = [] await self._init(**self.kwargs) if self.left <= 0: # <= 0 because subclasses may change it @@ -71,7 +72,9 @@ class RequestIter(abc.ABC): self.last_load = time.time() self.index = 0 - self.buffer = await self._load_next_chunk() + self.buffer = [] + if await self._load_next_chunk(): + self.left = len(self.buffer) if not self.buffer: raise StopAsyncIteration @@ -82,7 +85,7 @@ class RequestIter(abc.ABC): return result def __aiter__(self): - self.buffer = () + self.buffer = None self.index = 0 self.last_load = 0 self.left = self.limit @@ -113,7 +116,12 @@ class RequestIter(abc.ABC): async def _load_next_chunk(self): """ Called when the next chunk is necessary. - It should *always* return a `list`. + + It should extend the `buffer` with new items. + + It should return ``True`` if it's the last chunk, + after which moment the method won't be called again + during the same iteration. """ raise NotImplementedError