diff --git a/telethon/client/auth.py b/telethon/client/auth.py index d30ba6d2..9ca978c5 100644 --- a/telethon/client/auth.py +++ b/telethon/client/auth.py @@ -375,11 +375,6 @@ class AuthMethods(MessageParseMethods, UserMethods): self._self_input_peer = utils.get_input_peer(user, allow_self=False) self._authorized = True - # `catch_up` will getDifference from pts = 1, date = 1 (ignored) - # to fetch all updates (and obtain necessary access hashes) if - # the ``pts is None``. - self._old_pts_date = (None, None) - return user async def send_code_request(self, phone, *, force_sms=False): @@ -436,8 +431,7 @@ class AuthMethods(MessageParseMethods, UserMethods): self._bot = None self._self_input_peer = None self._authorized = False - self._old_pts_date = (None, None) - self._new_pts_date = (None, None) + self._state_cache.reset() await self.disconnect() self.session.delete() diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index ebc82b6c..90ffe813 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -13,6 +13,7 @@ from ..sessions import Session, SQLiteSession, MemorySession from ..tl import TLObject, functions, types from ..tl.alltlobjects import LAYER from ..entitycache import EntityCache +from ..statecache import StateCache DEFAULT_DC_ID = 4 DEFAULT_IPV4_IP = '149.154.167.51' @@ -306,13 +307,8 @@ class TelegramBaseClient(abc.ABC): self._authorized = None # None = unknown, False = no, True = yes # Update state (for catching up after a disconnection) - # - # We only care about the pts and the date. By using a tuple which - # is lightweight and immutable we can easily copy them around to - # each update in case they need to fetch missing entities. - state = self.session.get_update_state(0) - self._old_pts_date = (state.pts, state.date) if state else (None, None) - self._new_pts_date = (None, None) + # TODO Get state from channels too + self._state_cache = StateCache(self.session.get_update_state(0)) # Some further state for subclasses self._event_builders = [] @@ -395,15 +391,14 @@ class TelegramBaseClient(abc.ABC): async def _disconnect_coro(self): await self._disconnect() - pts, date = self._new_pts_date - if pts: - self.session.set_update_state(0, types.updates.State( - pts=pts, - qts=0, - date=date or datetime.now(), - seq=0, - unread_count=0 - )) + pts, date = self._state_cache[None] + self.session.set_update_state(0, types.updates.State( + pts=pts, + qts=0, + date=date, + seq=0, + unread_count=0 + )) self.session.close() diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 5d313097..9b3e3190 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -8,6 +8,7 @@ from .users import UserMethods from .. import events, utils, errors from ..tl import types, functions from ..events.common import EventCommon +from ..statecache import StateCache class UpdateMethods(UserMethods): @@ -135,14 +136,7 @@ class UpdateMethods(UserMethods): This can also be used to forcibly fetch new updates if there are any. """ - # TODO Since which state should we catch up? - if all(self._new_pts_date): - pts, date = self._new_pts_date - elif all(self._old_pts_date): - pts, date = self._old_pts_date - else: - return - + pts, date = self._state_cache[None] self.session.catching_up = True try: while True: @@ -192,7 +186,7 @@ class UpdateMethods(UserMethods): pass finally: # TODO Save new pts to session - self._new_pts_date = (pts, date) + self._state_cache._pts_date = (pts, date) self.session.catching_up = False # endregion @@ -211,19 +205,15 @@ class UpdateMethods(UserMethods): itertools.chain(update.users, update.chats)} for u in update.updates: self._process_update(u, entities) - - self._new_pts_date = (self._new_pts_date[0], update.date) elif isinstance(update, types.UpdateShort): self._process_update(update.update) - self._new_pts_date = (self._new_pts_date[0], update.date) else: self._process_update(update) - # TODO Should this be done before or after? - self._update_pts_date(update) + self._state_cache.update(update) def _process_update(self, update, entities=None): - update._pts_date = self._new_pts_date + update._pts_date = self._state_cache[StateCache.get_channel_id(update)] update._entities = entities or {} if self._updates_queue is None: self._loop.create_task(self._dispatch_update(update)) @@ -233,17 +223,7 @@ class UpdateMethods(UserMethods): self._dispatching_updates_queue.set() self._loop.create_task(self._dispatch_queue_updates()) - self._update_pts_date(update) - - def _update_pts_date(self, update): - pts, date = self._new_pts_date - if getattr(update, 'pts', None): - pts = update.pts - - if getattr(update, 'date', None): - date = update.date - - self._new_pts_date = (pts, date) + self._state_cache.update(update) async def _update_loop(self): # Pings' ID don't really need to be secure, just "random" @@ -416,7 +396,12 @@ class EventBuilderDict: """ # Fetch since the last known pts/date before this update arrived, # in order to fetch this update at full. - pts, date = self.update._pts_date + pts_date = self.update._pts_date + if not isinstance(pts_date, tuple): + # TODO Handle channels, and handle this more nicely + return + + pts, date = pts_date if not pts: return diff --git a/telethon/statecache.py b/telethon/statecache.py new file mode 100644 index 00000000..406ee80b --- /dev/null +++ b/telethon/statecache.py @@ -0,0 +1,117 @@ +import datetime + +from .tl import types + + +class StateCache: + """ + In-memory update state cache, defaultdict-like behaviour. + """ + def __init__(self, initial): + # We only care about the pts and the date. By using a tuple which + # is lightweight and immutable we can easily copy them around to + # each update in case they need to fetch missing entities. + if initial: + self._pts_date = initial.pts, initial.date + else: + self._pts_date = 1, datetime.datetime.now() + + def reset(self): + self.__dict__.clear() + self._pts_date = (1, 1) + + # TODO Call this when receiving responses too...? + def update( + self, + update, + *, + channel_id=None, + has_pts=( + types.UpdateNewMessage, + types.UpdateDeleteMessages, + types.UpdateReadHistoryInbox, + types.UpdateReadHistoryOutbox, + types.UpdateWebPage, + types.UpdateReadMessagesContents, + types.UpdateEditMessage, + types.updates.State, + types.updates.DifferenceTooLong, + types.UpdateShortMessage, + types.UpdateShortChatMessage, + types.UpdateShortSentMessage + ), + has_date=( + types.UpdateUserPhoto, + types.UpdateEncryption, + types.UpdateEncryptedMessagesRead, + types.UpdateChatParticipantAdd, + types.updates.DifferenceEmpty, + types.UpdateShortMessage, + types.UpdateShortChatMessage, + types.UpdateShort, + types.UpdatesCombined, + types.Updates, + types.UpdateShortSentMessage, + ), + has_channel_pts=( + types.UpdateChannelTooLong, + types.UpdateNewChannelMessage, + types.UpdateDeleteChannelMessages, + types.UpdateEditChannelMessage, + types.UpdateChannelWebPage, + types.updates.ChannelDifferenceEmpty, + types.updates.ChannelDifferenceTooLong, + types.updates.ChannelDifference + ) + ): + """ + Update the state with the given update. + """ + has_pts = isinstance(update, has_pts) + has_date = isinstance(update, has_date) + has_channel_pts = isinstance(update, has_channel_pts) + if has_pts and has_date: + self._pts_date = update.pts, update.date + elif has_pts: + self._pts_date = update.pts, self._pts_date[1] + elif has_date: + self._pts_date = self._pts_date[0], update.date + + if has_channel_pts: + if channel_id is None: + channel_id = self.get_channel_id(update) + + if channel_id is None: + pass # TODO log, but shouldn't happen + else: + self.__dict__[channel_id] = update.pts + + @staticmethod + def get_channel_id( + update, + has_channel_id=( + types.UpdateChannelTooLong, + types.UpdateDeleteChannelMessages, + types.UpdateChannelWebPage + ), + has_message=( + types.UpdateNewChannelMessage, + types.UpdateEditChannelMessage + ) + ): + # Will only fail for *difference, where channel_id is known + if isinstance(update, has_channel_id): + return update.channel_id + elif isinstance(update, has_message): + return update.message.to_id.channel_id + else: + return None + + def __getitem__(self, item): + """ + Gets the corresponding ``(pts, date)`` for the given ID or peer, + """ + if item is None: + return self._pts_date + else: + return self.__dict__.get(item, 1)