diff --git a/telethon/client/updates.py b/telethon/client/updates.py index b15853b7..ea129cf4 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -2,10 +2,12 @@ import asyncio import itertools import random import time +import datetime from .users import UserMethods from .. import events, utils, errors from ..tl import types, functions +from ..events.common import EventCommon class UpdateMethods(UserMethods): @@ -260,17 +262,17 @@ class UpdateMethods(UserMethods): built = EventBuilderDict(self, update) if self._conversations: for conv in self._conversations.values(): - if built[events.NewMessage]: + if await built.get(events.NewMessage): conv._on_new_message(built[events.NewMessage]) - if built[events.MessageEdited]: + if await built.get(events.MessageEdited): conv._on_edit(built[events.MessageEdited]) - if built[events.MessageRead]: + if await built.get(events.MessageRead): conv._on_read(built[events.MessageRead]) if conv._custom: await conv._check_custom(built) for builder, callback in self._event_builders: - event = built[type(builder)] + event = await built.get(type(builder)) if not event: continue @@ -322,15 +324,54 @@ class EventBuilderDict: self.update = update def __getitem__(self, builder): + return self.__dict__[builder] + + async def get(self, builder): try: return self.__dict__[builder] except KeyError: event = self.__dict__[builder] = builder.build(self.update) - if event: + if isinstance(event, EventCommon): event.original_update = self.update - if hasattr(event, '_set_client'): - event._set_client(self.client) - else: - event._client = self.client + event._set_client(self.client) + if not event._load_entities(): + await self.get_difference() + if not event._load_entities(): + self.client._log[__name__].info( + 'Could not find all entities for update.pts = %s', + getattr(self.update, 'pts', None) + ) + elif event: + # Actually a :tl:`Update`, not much processing to do + event._client = self.client return event + + async def get_difference(self): + """ + Calls :tl:`updates.getDifference`, which fills the entities cache + (always done by `__call__`) and lets us know about the full entities. + """ + pts = getattr(self.update, 'pts', None) + if not pts: + return + + date = getattr(self.update, 'date', None) + if date: + # Get the difference from one second ago to now + date -= datetime.timedelta(seconds=1) + else: + # No date known, 1 is the earliest date that works + date = 1 + + self.client._log[__name__].debug('Getting difference for entities') + result = await self.client(functions.updates.GetDifferenceRequest( + pts, date, 0 + )) + + if isinstance(result, (types.updates.Difference, + types.updates.DifferenceSlice)): + self.update._entities.update({ + utils.get_peer_id(x): x for x in + itertools.chain(result.users, result.chats) + }) diff --git a/telethon/events/common.py b/telethon/events/common.py index 1470585a..db644efd 100644 --- a/telethon/events/common.py +++ b/telethon/events/common.py @@ -146,17 +146,23 @@ class EventCommon(ChatGetter, abc.ABC): Setter so subclasses can act accordingly when the client is set. """ self._client = client - self._chat = self._entities.get(self.chat_id) - if not self._chat: - return + def _load_entities(self): + """ + Must load all the entities it needs from cache, and + return ``False`` if it could not find all of them. + """ + # TODO Make sure all subclasses implement this + self._chat = self._entities.get(self.chat_id) try: self._input_chat = utils.get_input_peer(self._chat) except TypeError: try: self._input_chat = self._client._entity_cache[self._chat_peer] except KeyError: - self._input_chat = None + return False + + return True @property def client(self):