Save pts and date in a tuple for immutability

This way it is easy and cheap to copy the two required values
to all incoming updates in case we need to getDifference since
the previous pts/date to fetch entities.

This is still a work in progress.
This commit is contained in:
Lonami Exo
2019-04-10 21:09:15 +04:00
parent bec0fa414e
commit 9965cda968
4 changed files with 68 additions and 76 deletions

View File

@@ -135,23 +135,19 @@ class UpdateMethods(UserMethods):
This can also be used to forcibly fetch new updates if there are any.
"""
state = self._new_state if self._old_state_is_new else self._old_state
if not self._old_state_is_new and self._new_state:
max_pts = self._new_state.pts
# TODO Since which state should we catch up?
if all(self._new_pts_date):
pts, date = self._new_pts_date
elif all(self._old_pts_date):
pts, date = self._new_pts_date
else:
max_pts = float('inf')
# No known state -> catch up since the beginning (date is ignored).
# Note: pts = 0 is invalid (and so is no date/unix timestamp = 0).
if not state:
state = types.updates.State(
1, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0)
return
self.session.catching_up = True
try:
while True:
d = await self(functions.updates.GetDifferenceRequest(
state.pts, state.date, state.qts
pts, date, 0
))
if isinstance(d, (types.updates.DifferenceSlice,
types.updates.Difference)):
@@ -160,7 +156,8 @@ class UpdateMethods(UserMethods):
else:
state = d.intermediate_state
await self._handle_update(types.Updates(
pts, date = state.pts, state.date
self._handle_update(types.Updates(
users=d.users,
chats=d.chats,
date=state.date,
@@ -171,6 +168,7 @@ class UpdateMethods(UserMethods):
]
))
# TODO Implement upper limit (max_pts)
# We don't want to fetch updates we already know about.
#
# We may still get duplicates because the Difference
@@ -184,29 +182,27 @@ class UpdateMethods(UserMethods):
# there would be duplicate updates since we know about
# some). This can be used to detect collisions (i.e.
# it would return an update we have already seen).
if state.pts >= max_pts:
break
else:
if isinstance(d, types.updates.DifferenceEmpty):
state.date = d.date
state.seq = d.seq
date = d.date
elif isinstance(d, types.updates.DifferenceTooLong):
state.pts = d.pts
pts = d.pts
break
except (ConnectionError, asyncio.CancelledError):
pass
finally:
self._old_state = None
self._new_state = state
self._old_state_is_new = True
self.session.set_update_state(0, state)
# TODO Save new pts to session
self._new_pts_date = (pts, date)
self.session.catching_up = False
# endregion
# region Private methods
async def _handle_update(self, update):
# It is important to not make _handle_update async because we rely on
# the order that the updates arrive in to update the pts and date to
# be always-increasing. There is also no need to make this async.
def _handle_update(self, update):
self.session.process_entities(update)
self._entity_cache.add(update)
@@ -214,40 +210,39 @@ class UpdateMethods(UserMethods):
entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)}
for u in update.updates:
u._entities = entities
await self._handle_update(u)
self._process_update(u, entities)
self._new_pts_date = (self._new_pts_date[0], update.date)
elif isinstance(update, types.UpdateShort):
await self._handle_update(update.update)
self._process_update(update.update)
self._new_pts_date = (self._new_pts_date[0], update.date)
else:
update._entities = getattr(update, '_entities', {})
if self._updates_queue is None:
self._loop.create_task(self._dispatch_update(update))
else:
self._updates_queue.put_nowait(update)
if not self._dispatching_updates_queue.is_set():
self._dispatching_updates_queue.set()
self._loop.create_task(self._dispatch_queue_updates())
self._process_update(update)
# TODO make use of need_diff
need_diff = False
# TODO Should this be done before or after?
self._update_pts_date(update)
def _process_update(self, update, entities=None):
update._entities = entities or {}
if self._updates_queue is None:
self._loop.create_task(self._dispatch_update(update))
else:
self._updates_queue.put_nowait(update)
if not self._dispatching_updates_queue.is_set():
self._dispatching_updates_queue.set()
self._loop.create_task(self._dispatch_queue_updates())
self._update_pts_date(update)
def _update_pts_date(self, update):
pts, date = self._new_pts_date
if getattr(update, 'pts', None):
if not self._new_state:
self._new_state = types.updates.State(
update.pts,
0,
getattr(update, 'date', datetime.datetime.now(tz=datetime.timezone.utc)),
getattr(update, 'seq', 0),
0
)
else:
if self._new_state.pts and (update.pts - self._new_state.pts) > 1:
need_diff = True
pts = update.pts
self._new_state.pts = update.pts
if hasattr(update, 'date'):
self._new_state.date = update.date
if hasattr(update, 'seq'):
self._new_state.seq = update.seq
if getattr(update, 'date', None):
date = update.date
self._new_pts_date = (pts, date)
async def _update_loop(self):
# Pings' ID don't really need to be secure, just "random"
@@ -368,10 +363,6 @@ class UpdateMethods(UserMethods):
# If a disconnection occurs, the old known state will be
# the latest one we were aware of, so we can catch up since
# the most recent state we were aware of.
# TODO Ideally we set _old_state = _new_state *on* disconnect,
# not *after* we managed to reconnect since perhaps an update
# arrives just before we can get started.
self._old_state_is_new = True
await self.catch_up()
self._log[__name__].info('Successfully fetched missed updates')