Make use of the async_generator module

This commit is contained in:
Lonami Exo
2018-06-10 21:50:28 +02:00
parent 15ef302428
commit 8be6adeab4
3 changed files with 21 additions and 11 deletions

View File

@@ -5,6 +5,8 @@ import time
import warnings
from collections import UserList
from async_generator import async_generator, yield_
from .messageparse import MessageParseMethods
from .uploads import UploadMethods
from .. import utils
@@ -19,6 +21,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# region Message retrieval
@async_generator
async def iter_messages(
self, entity, limit=None, offset_date=None, offset_id=0,
max_id=0, min_id=0, add_offset=0, search=None, filter=None,
@@ -114,7 +117,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
if not utils.is_list_like(ids):
ids = (ids,)
async for x in self._iter_ids(entity, ids, total=_total):
yield x
await yield_(x)
return
# Telegram doesn't like min_id/max_id. If these IDs are low enough
@@ -202,7 +205,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# IDs are returned in descending order.
last_id = message.id
yield custom.Message(self, message, entities, entity)
await yield_(custom.Message(self, message, entities, entity))
have += 1
if len(r.messages) < request.limit:
@@ -620,6 +623,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# region Private methods
@async_generator
async def _iter_ids(self, entity, ids, total):
"""
Special case for `iter_messages` when it should only fetch some IDs.
@@ -634,7 +638,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
if isinstance(r, types.messages.MessagesNotModified):
for _ in ids:
yield None
await yield_(None)
return
entities = {utils.get_peer_id(x): x
@@ -644,8 +648,8 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# we asked them for, so we don't need to check it ourselves.
for message in r.messages:
if isinstance(message, types.MessageEmpty):
yield None
await yield_(None)
else:
yield custom.Message(self, message, entities, entity)
await yield_(custom.Message(self, message, entities, entity))
# endregion