# Mirror of https://github.com/LonamiWebs/Telethon.git
# (synced 2025-06-17 10:36:37 +00:00; 639 lines, 22 KiB, Python)
import asyncio
|
|
import datetime
|
|
import logging
|
|
import time
|
|
from typing import Dict, List, Optional, Set, Tuple
|
|
|
|
from ...tl import Request, abcs, functions, types
|
|
from ..chat import ChatHashCache
|
|
from ..session import ChannelState, UpdateState
|
|
from .adaptor import adapt, pts_info_from_update
|
|
from .defs import (
|
|
BOT_CHANNEL_DIFF_LIMIT,
|
|
ENTRY_ACCOUNT,
|
|
ENTRY_SECRET,
|
|
LOG_LEVEL_TRACE,
|
|
NO_PTS,
|
|
NO_SEQ,
|
|
NO_UPDATES_TIMEOUT,
|
|
POSSIBLE_GAP_TIMEOUT,
|
|
USER_CHANNEL_DIFF_LIMIT,
|
|
Entry,
|
|
Gap,
|
|
PossibleGap,
|
|
PrematureEndReason,
|
|
State,
|
|
)
|
|
|
|
|
|
def next_updates_deadline() -> float:
    """Return the event-loop time at which the "no updates" timeout expires.

    The deadline is ``NO_UPDATES_TIMEOUT`` seconds after the current time of
    the running loop. ``asyncio.get_running_loop`` raises ``RuntimeError``
    when no loop is running, so this must be called from within one.
    """
    running_loop = asyncio.get_running_loop()
    now = running_loop.time()
    return now + NO_UPDATES_TIMEOUT


def epoch() -> datetime.datetime:
    """Return the platform epoch (``time.gmtime(0)``) as an aware UTC datetime.

    Used as the "no date known" placeholder for ``MessageBox.date``.
    """
    year, month, day, hour, minute, second = time.gmtime(0)[:6]
    return datetime.datetime(
        year, month, day, hour, minute, second, tzinfo=datetime.timezone.utc
    )


# https://core.telegram.org/api/updates#message-related-event-sequences.
class MessageBox:
    """Track update sequences and decide when a difference must be fetched.

    ``map`` holds one ``State`` (local ``pts`` plus a "no updates" deadline)
    per entry: ``ENTRY_ACCOUNT`` for the account-wide ``pts``, ``ENTRY_SECRET``
    for the secret-chat ``qts``, and one entry per channel id. ``date`` and
    ``seq`` mirror the account-wide update state. Updates that arrive ahead of
    the local ``pts`` are buffered in ``possible_gaps``; once their deadline
    expires (or a gap is detected) the entry is put in ``getting_diff_for``.

    Deadlines are in ``asyncio`` event-loop time, so most methods must be
    called while a loop is running.
    """

    __slots__ = (
        "_log",
        "map",
        "date",
        "seq",
        "possible_gaps",
        "getting_diff_for",
        "next_deadline",
    )

    def __init__(
        self,
        *,
        log: Optional[logging.Logger] = None,
    ) -> None:
        # Logger used by _trace; defaults to "telethon.messagebox".
        self._log = log or logging.getLogger("telethon.messagebox")
        # Per-entry local pts and "no updates" deadline.
        self.map: Dict[Entry, State] = {}
        self.date = epoch()
        self.seq = NO_SEQ
        # Out-of-order updates buffered per entry until the missing ones arrive.
        self.possible_gaps: Dict[Entry, PossibleGap] = {}
        # Entries whose difference must be (or is being) fetched.
        self.getting_diff_for: Set[Entry] = set()
        # Entry believed to hold the closest deadline, if any is known.
        self.next_deadline: Optional[Entry] = None

        if __debug__:
            self._trace("MessageBox initialized")

    def _trace(self, msg: str, *args: object) -> None:
        """Log the current state, then ``msg % args``, at ``LOG_LEVEL_TRACE``."""
        # Calls to trace can't really be removed beforehand without some dark magic.
        # So every call to trace is prefixed with `if __debug__` instead, to remove
        # it when using `python -O`. Probably unnecessary, but it's nice to avoid
        # paying the cost for something that is not used.
        self._log.log(
            LOG_LEVEL_TRACE,
            "Current MessageBox state: seq = %r, date = %s, map = %r",
            self.seq,
            self.date.isoformat(),
            self.map,
        )
        self._log.log(LOG_LEVEL_TRACE, msg, *args)

    def _set_pts(self, entry: Entry, pts: int) -> None:
        """Set the local pts of ``entry``, creating its state if it is unknown."""
        state = self.map.get(entry)
        if state is None:
            self.map[entry] = State(pts=pts, deadline=next_updates_deadline())
        else:
            state.pts = pts

    def load(self, state: UpdateState) -> None:
        """Replace all tracked state with the persisted session ``state``.

        An account ``pts`` or secret ``qts`` equal to ``NO_PTS`` (what
        ``session_state`` writes when the value is unknown) leaves that entry
        out of the map. Pending gaps and difference fetches are discarded.
        """
        if __debug__:
            self._trace(
                "Loading MessageBox with state = %r",
                state,
            )

        deadline = next_updates_deadline()

        self.map.clear()
        # Compare against the same sentinel session_state() uses for "unknown".
        if state.pts != NO_PTS:
            self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline)
        if state.qts != NO_PTS:
            self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline)
        self.map.update(
            (s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels
        )

        self.date = datetime.datetime.fromtimestamp(
            state.date, tz=datetime.timezone.utc
        )
        self.seq = state.seq
        self.possible_gaps.clear()
        self.getting_diff_for.clear()
        self.next_deadline = ENTRY_ACCOUNT

    def session_state(self) -> UpdateState:
        """Snapshot the tracked state so it can be persisted and later loaded.

        An unknown account ``pts`` or secret ``qts`` is written as ``NO_PTS``.
        Every other map entry is saved as a channel.
        """
        return UpdateState(
            pts=self.map[ENTRY_ACCOUNT].pts if ENTRY_ACCOUNT in self.map else NO_PTS,
            qts=self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else NO_PTS,
            date=int(self.date.timestamp()),
            seq=self.seq,
            channels=[
                ChannelState(id=int(entry), pts=state.pts)
                for entry, state in self.map.items()
                if entry not in (ENTRY_ACCOUNT, ENTRY_SECRET)
            ],
        )

    def is_empty(self) -> bool:
        """Return whether no account-wide pts is known (missing or ``NO_PTS``)."""
        # Compare the stored pts, not the State object itself.
        state = self.map.get(ENTRY_ACCOUNT)
        return (state.pts if state is not None else NO_PTS) == NO_PTS

    def check_deadlines(self) -> float:
        """Return the closest pending deadline, in event-loop time.

        If any entry is already awaiting a difference, the current time is
        returned immediately. Otherwise the closest deadline is computed from
        the buffered gaps (or the tracked ``next_deadline``), capped by a fresh
        "no updates" deadline. If it has already passed, every expired gap and
        entry is marked as needing a difference, and those gaps are discarded.
        """
        now = asyncio.get_running_loop().time()

        if self.getting_diff_for:
            return now

        default_deadline = next_updates_deadline()

        if self.possible_gaps:
            deadline = min(
                default_deadline, *(gap.deadline for gap in self.possible_gaps.values())
            )
        elif self.next_deadline in self.map:
            deadline = min(default_deadline, self.map[self.next_deadline].deadline)
        else:
            deadline = default_deadline

        if now >= deadline:
            self.getting_diff_for.update(
                entry
                for entry, gap in self.possible_gaps.items()
                if now >= gap.deadline
            )
            self.getting_diff_for.update(
                entry for entry, state in self.map.items() if now >= state.deadline
            )

            if __debug__:
                self._trace(
                    "Deadlines met, now getting diff for %r", self.getting_diff_for
                )

            for entry in self.getting_diff_for:
                self.possible_gaps.pop(entry, None)

        return deadline

    def reset_deadlines(self, entries: Set[Entry], deadline: float) -> None:
        """Set the "no updates" deadline of every entry in ``entries``.

        ``next_deadline`` is also kept pointing at the entry with the closest
        deadline.

        Raises:
            RuntimeError: if any of ``entries`` has no state in the map.
        """
        if not entries:
            return

        for entry in entries:
            if entry not in self.map:
                raise RuntimeError(
                    "Called reset_deadline on an entry for which we do not have state"
                )
            self.map[entry].deadline = deadline

        if self.next_deadline in entries or self.next_deadline not in self.map:
            # The closest deadline moved, was never set, or refers to an entry
            # that has since been removed: recompute it from scratch. The map is
            # non-empty here because every entry in `entries` is in it.
            self.next_deadline = min(
                self.map.items(), key=lambda entry_state: entry_state[1].deadline
            )[0]
        elif deadline < self.map[self.next_deadline].deadline:
            # All reset entries share the same new deadline, so any of them will do.
            self.next_deadline = next(iter(entries))

    def reset_channel_deadline(self, channel_id: int, timeout: Optional[float]) -> None:
        """Push back a channel's "no updates" deadline to ``timeout`` seconds from now.

        A falsy ``timeout`` (``None`` or ``0``) falls back to ``NO_UPDATES_TIMEOUT``.
        """
        self.reset_deadlines(
            {channel_id},
            asyncio.get_running_loop().time() + (timeout or NO_UPDATES_TIMEOUT),
        )

    def set_state(self, state: abcs.updates.State) -> None:
        """Adopt an account-wide ``updates.State`` (pts, qts, date and seq)."""
        if __debug__:
            self._trace("Setting state %s", state)

        deadline = next_updates_deadline()
        assert isinstance(state, types.updates.State)
        self.map[ENTRY_ACCOUNT] = State(state.pts, deadline)
        self.map[ENTRY_SECRET] = State(state.qts, deadline)
        self.date = datetime.datetime.fromtimestamp(
            state.date, tz=datetime.timezone.utc
        )
        self.seq = state.seq
        # Like load(): without a tracked next_deadline, check_deadlines would
        # ignore the per-entry deadlines entirely.
        if self.next_deadline not in self.map:
            self.next_deadline = ENTRY_ACCOUNT

    def try_set_channel_state(self, id: int, pts: int) -> None:
        """Start tracking channel ``id`` at ``pts`` unless it is already tracked."""
        if __debug__:
            self._trace("Trying to set channel state for %r: %r", id, pts)

        if id not in self.map:
            self.map[id] = State(pts=pts, deadline=next_updates_deadline())

    def try_begin_get_diff(self, entry: Entry, reason: str) -> None:
        """Mark ``entry`` as needing a difference and discard its buffered gap.

        Entries without known state are ignored.

        Raises:
            RuntimeError: if an unknown entry nevertheless has a possible gap.
        """
        if entry not in self.map:
            if entry in self.possible_gaps:
                raise RuntimeError(
                    "Should not have a possible_gap for an entry not in the state map"
                )
            return

        if __debug__:
            self._trace("Marking %r as needing difference because %s", entry, reason)
        self.getting_diff_for.add(entry)
        self.possible_gaps.pop(entry, None)

    def end_get_diff(self, entry: Entry) -> None:
        """Finish fetching the difference of ``entry`` and reset its deadline.

        Raises:
            RuntimeError: if ``entry`` was not awaiting a difference.
        """
        try:
            self.getting_diff_for.remove(entry)
        except KeyError:
            raise RuntimeError(
                "Called end_get_diff on an entry which was not getting diff for"
            ) from None

        self.reset_deadlines({entry}, next_updates_deadline())
        assert (
            entry not in self.possible_gaps
        ), "gaps shouldn't be created while getting difference"

    def ensure_known_peer_hashes(
        self,
        updates: abcs.Updates,
        chat_hashes: ChatHashCache,
    ) -> None:
        """Record the peer hashes carried by ``updates`` in ``chat_hashes``.

        If ``extend_from_updates`` reports a failure, ``Gap`` is raised. Before
        raising, the account-wide entry is marked as needing a difference,
        unless ``updates`` is an ``UpdateShort`` whose update has no pts info.
        """
        if not chat_hashes.extend_from_updates(updates):
            can_recover = (
                not isinstance(updates, types.UpdateShort)
                or pts_info_from_update(updates.update) is not None
            )
            if can_recover:
                self.try_begin_get_diff(ENTRY_ACCOUNT, "missing hash")
            raise Gap

    # https://core.telegram.org/api/updates
    def process_updates(
        self,
        updates: abcs.Updates,
        chat_hashes: ChatHashCache,
    ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
        """Apply an ``Updates`` container and return ``(updates, users, chats)``.

        Only updates that may be delivered now are returned. Updates that are
        ahead of the local pts are buffered as possible gaps, and buffered gaps
        are retried afterwards.

        Raises:
            Gap: if ``seq_start`` shows missed updates. The account-wide entry
                is marked as needing a difference first.
        """
        result: List[abcs.Update] = []
        combined = adapt(updates, chat_hashes)

        if __debug__:
            self._trace(
                "Processing updates with seq = %r, seq_start = %r, date = %r: %s",
                combined.seq,
                combined.seq_start,
                combined.date,
                updates,
            )

        if combined.seq_start != NO_SEQ:
            if self.seq + 1 > combined.seq_start:
                if __debug__:
                    self._trace(
                        "Skipping updates as they should have already been handled"
                    )
                return result, combined.users, combined.chats
            elif self.seq + 1 < combined.seq_start:
                self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap")
                raise Gap

        def update_sort_key(update: abcs.Update) -> int:
            # Order by the pts *preceding* each update so they apply in sequence;
            # updates without pts get key 0.
            pts = pts_info_from_update(update)
            return pts.pts - pts.pts_count if pts else 0

        any_pts_applied = False
        reset_deadlines_for: Set[Entry] = set()
        for update in sorted(combined.updates, key=update_sort_key):
            entry, applied = self.apply_pts_info(update)
            if entry is not None:
                reset_deadlines_for.add(entry)
            if applied is not None:
                result.append(applied)
            any_pts_applied |= entry is not None

        self.reset_deadlines(reset_deadlines_for, next_updates_deadline())

        if any_pts_applied:
            if __debug__:
                self._trace("Updating seq as local pts was updated too")
            # Containers synthesized while applying differences carry the epoch
            # as date and NO_SEQ as seq; those placeholders must not overwrite
            # the real account state (e.g. resetting the date to 1970).
            if combined.date != int(epoch().timestamp()):
                self.date = datetime.datetime.fromtimestamp(
                    combined.date, tz=datetime.timezone.utc
                )
            if combined.seq != NO_SEQ:
                self.seq = combined.seq

        if self.possible_gaps:
            if __debug__:
                self._trace(
                    "Trying to re-apply %r possible gaps", len(self.possible_gaps)
                )

            for key in list(self.possible_gaps):
                gap = self.possible_gaps[key]
                pending = sorted(gap.updates, key=update_sort_key)
                # apply_pts_info re-buffers any update that still doesn't fit into
                # this same gap, so start again from an empty list.
                gap.updates.clear()
                for update in pending:
                    _, applied = self.apply_pts_info(update)
                    if applied is not None:
                        result.append(applied)
                        if __debug__:
                            self._trace(
                                "Resolved gap with %r: %s",
                                pts_info_from_update(applied),
                                applied,
                            )

            self.possible_gaps = {
                entry: gap for entry, gap in self.possible_gaps.items() if gap.updates
            }

        return result, combined.users, combined.chats

    def apply_pts_info(
        self,
        update: abcs.Update,
    ) -> Tuple[Optional[Entry], Optional[abcs.Update]]:
        """Check ``update`` against the local pts of the entry it belongs to.

        Returns ``(entry, update)``:

        - ``entry`` is the entry named by the update's pts info, or ``None``
          if it has none.
        - ``update`` is the update to deliver, or ``None`` if it must not be
          delivered now: already applied, buffered as a possible gap, or its
          entry's difference is being fetched.

        ``updateChannelTooLong`` instead marks its channel as needing a
        difference and returns ``(None, None)``.
        """
        if isinstance(update, types.UpdateChannelTooLong):
            self.try_begin_get_diff(update.channel_id, "received updateChannelTooLong")
            return None, None

        pts = pts_info_from_update(update)
        if not pts:
            if __debug__:
                self._trace(
                    "No pts in update, so it can be applied in any order: %s", update
                )
            return None, update

        if pts.entry in self.getting_diff_for:
            if __debug__:
                self._trace(
                    "Skipping update with %r as its difference is being fetched", pts
                )
            return pts.entry, None

        if state := self.map.get(pts.entry):
            local_pts = state.pts
            if local_pts + pts.pts_count > pts.pts:
                if __debug__:
                    self._trace(
                        "Skipping update since local pts %r > %r: %s",
                        local_pts,
                        pts,
                        update,
                    )
                return pts.entry, None
            elif local_pts + pts.pts_count < pts.pts:
                # TODO store chats too?
                if __debug__:
                    self._trace(
                        "Possible gap since local pts %r < %r: %s",
                        local_pts,
                        pts,
                        update,
                    )
                if pts.entry not in self.possible_gaps:
                    self.possible_gaps[pts.entry] = PossibleGap(
                        deadline=asyncio.get_running_loop().time()
                        + POSSIBLE_GAP_TIMEOUT,
                        updates=[],
                    )

                self.possible_gaps[pts.entry].updates.append(update)
                return pts.entry, None
            else:
                if __debug__:
                    self._trace(
                        "Applying update pts since local pts %r = %r: %s",
                        local_pts,
                        pts,
                        update,
                    )

        # Either the pts fits exactly, or the entry was unknown and is now tracked.
        if pts.entry not in self.map:
            self.map[pts.entry] = State(
                pts=0,
                deadline=next_updates_deadline(),
            )
        self.map[pts.entry].pts = pts.pts

        return pts.entry, update

    def get_difference(self) -> Optional[Request[abcs.updates.Difference]]:
        """Build an ``updates.getDifference`` request if the account or secret
        entry is awaiting a difference, else return ``None``.

        Raises:
            RuntimeError: if the pending entry has no known state.
        """
        for entry in (ENTRY_ACCOUNT, ENTRY_SECRET):
            if entry in self.getting_diff_for:
                if entry not in self.map:
                    raise RuntimeError(
                        "Should not try to get difference for an entry without known state"
                    )

                gd = functions.updates.get_difference(
                    pts=self.map[ENTRY_ACCOUNT].pts,
                    pts_limit=None,
                    pts_total_limit=None,
                    date=int(self.date.timestamp()),
                    qts=self.map[ENTRY_SECRET].pts
                    if ENTRY_SECRET in self.map
                    else NO_SEQ,
                    qts_limit=None,
                )
                if __debug__:
                    self._trace("Requesting account difference %s", gd)
                return gd

        return None

    def apply_difference(
        self,
        diff: abcs.updates.Difference,
        chat_hashes: ChatHashCache,
    ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
        """Apply an account-wide difference and return ``(updates, users, chats)``.

        ``DifferenceEmpty``, ``Difference`` and ``DifferenceTooLong`` finish the
        fetch for whichever of the account and secret entries were pending.
        A ``DifferenceSlice`` leaves them pending.

        Raises:
            RuntimeError: on an unexpected difference type, or when finishing
                while neither entry was awaiting a difference.
        """
        if __debug__:
            self._trace("Applying account difference %s", diff)

        finish: bool
        result: Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]
        if isinstance(diff, types.updates.DifferenceEmpty):
            finish = True
            self.date = datetime.datetime.fromtimestamp(
                diff.date, tz=datetime.timezone.utc
            )
            self.seq = diff.seq
            result = [], [], []
        elif isinstance(diff, types.updates.Difference):
            chat_hashes.extend(diff.users, diff.chats)
            finish = True
            result = self.apply_difference_type(diff, chat_hashes)
        elif isinstance(diff, types.updates.DifferenceSlice):
            chat_hashes.extend(diff.users, diff.chats)
            finish = False
            result = self.apply_difference_type(
                types.updates.Difference(
                    new_messages=diff.new_messages,
                    new_encrypted_messages=diff.new_encrypted_messages,
                    other_updates=diff.other_updates,
                    chats=diff.chats,
                    users=diff.users,
                    state=diff.intermediate_state,
                ),
                chat_hashes,
            )
        elif isinstance(diff, types.updates.DifferenceTooLong):
            finish = True
            self._set_pts(ENTRY_ACCOUNT, diff.pts)
            result = [], [], []
        else:
            raise RuntimeError("unexpected case")

        if finish:
            account = ENTRY_ACCOUNT in self.getting_diff_for
            secret = ENTRY_SECRET in self.getting_diff_for

            if not account and not secret:
                raise RuntimeError(
                    "Should not be applying the difference when neither account or secret was diff was active"
                )

            if account:
                self.end_get_diff(ENTRY_ACCOUNT)
            if secret:
                self.end_get_diff(ENTRY_SECRET)

        return result

    def apply_difference_type(
        self,
        diff: types.updates.Difference,
        chat_hashes: ChatHashCache,
    ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
        """Adopt ``diff.state`` and turn the difference's contents into updates.

        ``other_updates`` go through ``process_updates`` in a container with
        the epoch date and ``NO_SEQ`` as placeholders. New messages and new
        encrypted messages are appended as pts-less update objects.
        """
        state = diff.state
        assert isinstance(state, types.updates.State)
        self._set_pts(ENTRY_ACCOUNT, state.pts)
        # The secret entry is absent when load() saw no qts; create it rather
        # than raising KeyError.
        self._set_pts(ENTRY_SECRET, state.qts)
        self.date = datetime.datetime.fromtimestamp(
            state.date, tz=datetime.timezone.utc
        )
        self.seq = state.seq

        updates, users, chats = self.process_updates(
            types.Updates(
                updates=diff.other_updates,
                users=diff.users,
                chats=diff.chats,
                date=int(epoch().timestamp()),
                seq=NO_SEQ,
            ),
            chat_hashes,
        )

        updates.extend(
            types.UpdateNewMessage(
                message=m,
                pts=NO_PTS,
                pts_count=0,
            )
            for m in diff.new_messages
        )
        updates.extend(
            types.UpdateNewEncryptedMessage(
                message=m,
                qts=NO_PTS,
            )
            for m in diff.new_encrypted_messages
        )

        return updates, users, chats

    def get_channel_difference(
        self,
        chat_hashes: ChatHashCache,
    ) -> Optional[Request[abcs.updates.ChannelDifference]]:
        """Build a ``getChannelDifference`` request for one pending channel.

        Returns ``None`` if no channel is awaiting a difference. If the
        channel's access hash is unknown, its difference is ended and its
        state dropped instead, also returning ``None``.

        Raises:
            RuntimeError: if the channel's hash is known but it has no state.
        """
        for entry in self.getting_diff_for:
            if entry not in (ENTRY_ACCOUNT, ENTRY_SECRET):
                channel_id = int(entry)
                break
        else:
            return None

        packed = chat_hashes.get(channel_id)
        if not packed:
            # The channel cannot be queried without its access hash. Forget its
            # (possibly outdated) pts: apply_pts_info accepts the next update for
            # an unknown entry, which re-establishes it.
            self.end_get_diff(entry)
            self.map.pop(entry, None)
            return None

        assert packed.access_hash is not None
        channel = types.InputChannel(
            channel_id=packed.id,
            access_hash=packed.access_hash,
        )
        if state := self.map.get(entry):
            gd = functions.updates.get_channel_difference(
                force=False,
                channel=channel,
                filter=types.ChannelMessagesFilterEmpty(),
                pts=state.pts,
                limit=BOT_CHANNEL_DIFF_LIMIT
                if chat_hashes.is_self_bot
                else USER_CHANNEL_DIFF_LIMIT,
            )
            if __debug__:
                self._trace("Requesting channel difference %s", gd)
            return gd

        raise RuntimeError(
            "Should not try to get difference for an entry without known state"
        )

    def apply_channel_difference(
        self,
        channel_id: int,
        diff: abcs.updates.ChannelDifference,
        chat_hashes: ChatHashCache,
    ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
        """Apply a channel difference and return ``(updates, users, chats)``.

        Every final difference ends the channel's fetch. The channel's pts and
        deadline are then updated.

        Raises:
            RuntimeError: on an unexpected difference or dialog type.
        """
        entry: Entry = channel_id
        if __debug__:
            self._trace("Applying channel difference for %r: %s", entry, diff)

        self.possible_gaps.pop(entry, None)

        if isinstance(diff, types.updates.ChannelDifferenceEmpty):
            assert diff.final
            self.end_get_diff(entry)
            self.map[entry].pts = diff.pts
            return [], [], []
        elif isinstance(diff, types.updates.ChannelDifferenceTooLong):
            chat_hashes.extend(diff.users, diff.chats)

            assert diff.final
            if isinstance(diff.dialog, types.Dialog):
                assert diff.dialog.pts is not None
                self.map[entry].pts = diff.dialog.pts
            else:
                raise RuntimeError("unexpected type on ChannelDifferenceTooLong")
            # Final, like the other branches, so the fetch ends here. Otherwise
            # the channel would stay in getting_diff_for: its live updates would
            # keep being dropped and the difference would be requested again.
            self.end_get_diff(entry)
            self.reset_channel_deadline(channel_id, diff.timeout)
            return [], [], []
        elif isinstance(diff, types.updates.ChannelDifference):
            chat_hashes.extend(diff.users, diff.chats)

            if diff.final:
                self.end_get_diff(entry)

            self.map[entry].pts = diff.pts
            updates, users, chats = self.process_updates(
                types.Updates(
                    updates=diff.other_updates,
                    users=diff.users,
                    chats=diff.chats,
                    date=int(epoch().timestamp()),
                    seq=NO_SEQ,
                ),
                chat_hashes,
            )

            updates.extend(
                types.UpdateNewChannelMessage(
                    message=m,
                    pts=NO_PTS,
                    pts_count=0,
                )
                for m in diff.new_messages
            )
            self.reset_channel_deadline(channel_id, None)

            return updates, users, chats
        else:
            raise RuntimeError("unexpected case")

    def end_channel_difference(
        self, channel_id: int, reason: PrematureEndReason
    ) -> None:
        """Abort the pending difference of a channel.

        Its buffered gap is dropped and its fetch ended. When ``reason`` is
        ``BANNED``, the channel's state is removed as well.

        Raises:
            RuntimeError: for any other ``reason``.
        """
        entry: Entry = channel_id
        if __debug__:
            self._trace("Ending channel difference for %r because %s", entry, reason)

        if reason not in (
            PrematureEndReason.TEMPORARY_SERVER_ISSUES,
            PrematureEndReason.BANNED,
        ):
            raise RuntimeError("Unknown reason to end channel difference")

        self.possible_gaps.pop(entry, None)
        self.end_get_diff(entry)
        if reason == PrematureEndReason.BANNED:
            # Stop tracking the channel entirely.
            del self.map[entry]