Fix pts from channels is different (#1160)

This commit is contained in:
Lonami Exo 2019-04-21 13:56:14 +02:00
parent 8edbfbdced
commit c1880c9191
4 changed files with 141 additions and 50 deletions

View File

@ -375,11 +375,6 @@ class AuthMethods(MessageParseMethods, UserMethods):
self._self_input_peer = utils.get_input_peer(user, allow_self=False) self._self_input_peer = utils.get_input_peer(user, allow_self=False)
self._authorized = True self._authorized = True
# `catch_up` will getDifference from pts = 1, date = 1 (ignored)
# to fetch all updates (and obtain necessary access hashes) if
# the ``pts is None``.
self._old_pts_date = (None, None)
return user return user
async def send_code_request(self, phone, *, force_sms=False): async def send_code_request(self, phone, *, force_sms=False):
@ -436,8 +431,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
self._bot = None self._bot = None
self._self_input_peer = None self._self_input_peer = None
self._authorized = False self._authorized = False
self._old_pts_date = (None, None) self._state_cache.reset()
self._new_pts_date = (None, None)
await self.disconnect() await self.disconnect()
self.session.delete() self.session.delete()

View File

@ -13,6 +13,7 @@ from ..sessions import Session, SQLiteSession, MemorySession
from ..tl import TLObject, functions, types from ..tl import TLObject, functions, types
from ..tl.alltlobjects import LAYER from ..tl.alltlobjects import LAYER
from ..entitycache import EntityCache from ..entitycache import EntityCache
from ..statecache import StateCache
DEFAULT_DC_ID = 4 DEFAULT_DC_ID = 4
DEFAULT_IPV4_IP = '149.154.167.51' DEFAULT_IPV4_IP = '149.154.167.51'
@ -306,13 +307,8 @@ class TelegramBaseClient(abc.ABC):
self._authorized = None # None = unknown, False = no, True = yes self._authorized = None # None = unknown, False = no, True = yes
# Update state (for catching up after a disconnection) # Update state (for catching up after a disconnection)
# # TODO Get state from channels too
# We only care about the pts and the date. By using a tuple which self._state_cache = StateCache(self.session.get_update_state(0))
# is lightweight and immutable we can easily copy them around to
# each update in case they need to fetch missing entities.
state = self.session.get_update_state(0)
self._old_pts_date = (state.pts, state.date) if state else (None, None)
self._new_pts_date = (None, None)
# Some further state for subclasses # Some further state for subclasses
self._event_builders = [] self._event_builders = []
@ -395,12 +391,11 @@ class TelegramBaseClient(abc.ABC):
async def _disconnect_coro(self): async def _disconnect_coro(self):
await self._disconnect() await self._disconnect()
pts, date = self._new_pts_date pts, date = self._state_cache[None]
if pts:
self.session.set_update_state(0, types.updates.State( self.session.set_update_state(0, types.updates.State(
pts=pts, pts=pts,
qts=0, qts=0,
date=date or datetime.now(), date=date,
seq=0, seq=0,
unread_count=0 unread_count=0
)) ))

View File

@ -8,6 +8,7 @@ from .users import UserMethods
from .. import events, utils, errors from .. import events, utils, errors
from ..tl import types, functions from ..tl import types, functions
from ..events.common import EventCommon from ..events.common import EventCommon
from ..statecache import StateCache
class UpdateMethods(UserMethods): class UpdateMethods(UserMethods):
@ -135,14 +136,7 @@ class UpdateMethods(UserMethods):
This can also be used to forcibly fetch new updates if there are any. This can also be used to forcibly fetch new updates if there are any.
""" """
# TODO Since which state should we catch up? pts, date = self._state_cache[None]
if all(self._new_pts_date):
pts, date = self._new_pts_date
elif all(self._old_pts_date):
pts, date = self._old_pts_date
else:
return
self.session.catching_up = True self.session.catching_up = True
try: try:
while True: while True:
@ -192,7 +186,7 @@ class UpdateMethods(UserMethods):
pass pass
finally: finally:
# TODO Save new pts to session # TODO Save new pts to session
self._new_pts_date = (pts, date) self._state_cache._pts_date = (pts, date)
self.session.catching_up = False self.session.catching_up = False
# endregion # endregion
@ -211,19 +205,15 @@ class UpdateMethods(UserMethods):
itertools.chain(update.users, update.chats)} itertools.chain(update.users, update.chats)}
for u in update.updates: for u in update.updates:
self._process_update(u, entities) self._process_update(u, entities)
self._new_pts_date = (self._new_pts_date[0], update.date)
elif isinstance(update, types.UpdateShort): elif isinstance(update, types.UpdateShort):
self._process_update(update.update) self._process_update(update.update)
self._new_pts_date = (self._new_pts_date[0], update.date)
else: else:
self._process_update(update) self._process_update(update)
# TODO Should this be done before or after? self._state_cache.update(update)
self._update_pts_date(update)
def _process_update(self, update, entities=None): def _process_update(self, update, entities=None):
update._pts_date = self._new_pts_date update._pts_date = self._state_cache[StateCache.get_channel_id(update)]
update._entities = entities or {} update._entities = entities or {}
if self._updates_queue is None: if self._updates_queue is None:
self._loop.create_task(self._dispatch_update(update)) self._loop.create_task(self._dispatch_update(update))
@ -233,17 +223,7 @@ class UpdateMethods(UserMethods):
self._dispatching_updates_queue.set() self._dispatching_updates_queue.set()
self._loop.create_task(self._dispatch_queue_updates()) self._loop.create_task(self._dispatch_queue_updates())
self._update_pts_date(update) self._state_cache.update(update)
def _update_pts_date(self, update):
pts, date = self._new_pts_date
if getattr(update, 'pts', None):
pts = update.pts
if getattr(update, 'date', None):
date = update.date
self._new_pts_date = (pts, date)
async def _update_loop(self): async def _update_loop(self):
# Pings' ID don't really need to be secure, just "random" # Pings' ID don't really need to be secure, just "random"
@ -416,7 +396,12 @@ class EventBuilderDict:
""" """
# Fetch since the last known pts/date before this update arrived, # Fetch since the last known pts/date before this update arrived,
# in order to fetch this update at full. # in order to fetch this update at full.
pts, date = self.update._pts_date pts_date = self.update._pts_date
if not isinstance(pts_date, tuple):
# TODO Handle channels, and handle this more nicely
return
pts, date = pts_date
if not pts: if not pts:
return return

117
telethon/statecache.py Normal file
View File

@ -0,0 +1,117 @@
import datetime
from .tl import types
class StateCache:
"""
In-memory update state cache, defaultdict-like behaviour.
"""
def __init__(self, initial):
# We only care about the pts and the date. By using a tuple which
# is lightweight and immutable we can easily copy them around to
# each update in case they need to fetch missing entities.
if initial:
self._pts_date = initial.pts, initial.date
else:
self._pts_date = 1, datetime.datetime.now()
def reset(self):
self.__dict__.clear()
self._pts_date = (1, 1)
# TODO Call this when receiving responses too...?
def update(
self,
update,
*,
channel_id=None,
has_pts=(
types.UpdateNewMessage,
types.UpdateDeleteMessages,
types.UpdateReadHistoryInbox,
types.UpdateReadHistoryOutbox,
types.UpdateWebPage,
types.UpdateReadMessagesContents,
types.UpdateEditMessage,
types.updates.State,
types.updates.DifferenceTooLong,
types.UpdateShortMessage,
types.UpdateShortChatMessage,
types.UpdateShortSentMessage
),
has_date=(
types.UpdateUserPhoto,
types.UpdateEncryption,
types.UpdateEncryptedMessagesRead,
types.UpdateChatParticipantAdd,
types.updates.DifferenceEmpty,
types.UpdateShortMessage,
types.UpdateShortChatMessage,
types.UpdateShort,
types.UpdatesCombined,
types.Updates,
types.UpdateShortSentMessage,
),
has_channel_pts=(
types.UpdateChannelTooLong,
types.UpdateNewChannelMessage,
types.UpdateDeleteChannelMessages,
types.UpdateEditChannelMessage,
types.UpdateChannelWebPage,
types.updates.ChannelDifferenceEmpty,
types.updates.ChannelDifferenceTooLong,
types.updates.ChannelDifference
)
):
"""
Update the state with the given update.
"""
has_pts = isinstance(update, has_pts)
has_date = isinstance(update, has_date)
has_channel_pts = isinstance(update, has_channel_pts)
if has_pts and has_date:
self._pts_date = update.pts, update.date
elif has_pts:
self._pts_date = update.pts, self._pts_date[1]
elif has_date:
self._pts_date = self._pts_date[0], update.date
if has_channel_pts:
if channel_id is None:
channel_id = self.get_channel_id(update)
if channel_id is None:
pass # TODO log, but shouldn't happen
else:
self.__dict__[channel_id] = update.pts
@staticmethod
def get_channel_id(
update,
has_channel_id=(
types.UpdateChannelTooLong,
types.UpdateDeleteChannelMessages,
types.UpdateChannelWebPage
),
has_message=(
types.UpdateNewChannelMessage,
types.UpdateEditChannelMessage
)
):
# Will only fail for *difference, where channel_id is known
if isinstance(update, has_channel_id):
return update.channel_id
elif isinstance(update, has_message):
return update.message.to_id.channel_id
else:
return None
def __getitem__(self, item):
"""
Gets the corresponding ``(pts, date)`` for the given ID or peer,
"""
if item is None:
return self._pts_date
else:
return self.__dict__.get(item, 1)