Rename more subpackages and modules

This commit is contained in:
Lonami Exo
2021-09-11 17:48:23 +02:00
parent 66ef553adc
commit a901d43a6d
61 changed files with 69 additions and 48 deletions

View File

@@ -0,0 +1,6 @@
"""
Several extensions Python is missing, such as a proper class to handle a TCP
communication with support for cancelling the operation, and a utility class
to read arbitrary binary data in a more comfortable way, with int/strings/etc.
"""
from .binaryreader import BinaryReader

View File

@@ -0,0 +1,185 @@
"""
This module contains the BinaryReader utility class.
"""
import os
import time
from datetime import datetime, timezone, timedelta
from io import BytesIO
from struct import unpack
from ..errors import TypeNotFoundError
from ..tl.alltlobjects import tlobjects
from ..tl.core import core_objects
# The Unix epoch as a naive datetime, plus its timezone-aware (UTC) twin.
_EPOCH_NAIVE = datetime(*time.gmtime(0)[:6])
_EPOCH = _EPOCH_NAIVE.replace(tzinfo=timezone.utc)


class BinaryReader:
    """
    Cursor over an in-memory byte buffer, with helpers to decode plain
    little-endian values and Telegram-encoded ones.
    """

    def __init__(self, data):
        self.stream = BytesIO(data)
        # Last chunk successfully read; included in the error message of
        # `read` to help spot -404 errors.
        self._last = None

    # region Reading

    # "All numbers are written as little endian."
    # https://core.telegram.org/mtproto
    def read_byte(self):
        """Reads one byte and returns it as an integer."""
        (value,) = self.read(1)
        return value

    def read_int(self, signed=True):
        """Reads a 4-byte little-endian integer."""
        raw = self.read(4)
        return int.from_bytes(raw, byteorder='little', signed=signed)

    def read_long(self, signed=True):
        """Reads an 8-byte little-endian integer."""
        raw = self.read(8)
        return int.from_bytes(raw, byteorder='little', signed=signed)

    def read_float(self):
        """Reads a 4-byte little-endian floating point value."""
        (value,) = unpack('<f', self.read(4))
        return value

    def read_double(self):
        """Reads an 8-byte little-endian floating point value."""
        (value,) = unpack('<d', self.read(8))
        return value

    def read_large_int(self, bits, signed=True):
        """Reads a little-endian integer that is ``bits`` bits wide."""
        raw = self.read(bits // 8)
        return int.from_bytes(raw, byteorder='little', signed=signed)

    def read(self, length=-1):
        """
        Reads exactly ``length`` bytes, or everything left if it's negative.

        Raises ``BufferError`` when fewer than ``length`` bytes remain.
        """
        chunk = self.stream.read(length)
        if length < 0 or len(chunk) == length:
            self._last = chunk
            return chunk

        raise BufferError(
            'No more data left to read (need {}, got {}: {}); last read {}'
            .format(length, len(chunk), repr(chunk), repr(self._last))
        )

    def get_bytes(self):
        """Returns the whole underlying buffer, regardless of position."""
        return self.stream.getvalue()

    # endregion

    # region Telegram custom reading

    def tgread_bytes(self):
        """
        Reads a Telegram-encoded byte array. The length is part of the
        encoding, so it does not need to be specified.
        """
        head = self.read_byte()
        if head != 254:
            # Short form: the single header byte is the length, and it
            # counts towards the 4-byte alignment.
            size = head
            misalign = (size + 1) % 4
        else:
            # Long form: three more bytes hold the length (little endian).
            low = self.read_byte()
            mid = self.read_byte()
            high = self.read_byte()
            size = low | (mid << 8) | (high << 16)
            misalign = size % 4

        payload = self.read(size)
        if misalign > 0:
            # Skip the padding that aligns the value to 4 bytes.
            self.read(4 - misalign)

        return payload

    def tgread_string(self):
        """Reads a Telegram-encoded string (UTF-8, bad bytes replaced)."""
        return self.tgread_bytes().decode('utf-8', errors='replace')

    def tgread_bool(self):
        """Reads a Telegram boolean, raising ``RuntimeError`` if invalid."""
        code = self.read_int(signed=False)
        if code == 0x997275b5:  # boolTrue
            return True
        if code == 0xbc799737:  # boolFalse
            return False

        raise RuntimeError('Invalid boolean code {}'.format(hex(code)))

    def tgread_date(self):
        """
        Reads a Unix timestamp (4 bytes) and returns it as a
        timezone-aware (UTC) Python datetime.
        """
        seconds = self.read_int()
        return _EPOCH + timedelta(seconds=seconds)

    def tgread_object(self):
        """Reads a Telegram object, dispatching on its constructor ID."""
        constructor_id = self.read_int(signed=False)
        clazz = tlobjects.get(constructor_id, None)
        if clazz is not None:
            return clazz.from_reader(self)

        # Not a known class, but it may still be one of the values
        # that are parsed manually (booleans and vectors).
        if constructor_id == 0x997275b5:  # boolTrue
            return True
        if constructor_id == 0xbc799737:  # boolFalse
            return False
        if constructor_id == 0x1cb5c415:  # Vector
            return [self.tgread_object() for _ in range(self.read_int())]

        clazz = core_objects.get(constructor_id, None)
        if clazz is not None:
            return clazz.from_reader(self)

        # Still no luck. Rewind over the constructor ID so the error can
        # carry the remaining bytes, then restore the position.
        self.seek(-4)
        pos = self.tell_position()
        error = TypeNotFoundError(constructor_id, self.read())
        self.set_position(pos)
        raise error

    def tgread_vector(self):
        """Reads a vector (a list) of Telegram objects."""
        if self.read_int(signed=False) != 0x1cb5c415:
            raise RuntimeError('Invalid constructor code, vector was expected')

        count = self.read_int()
        return [self.tgread_object() for _ in range(count)]

    # endregion

    def close(self):
        """Closes the underlying BytesIO stream."""
        self.stream.close()

    # region Position related

    def tell_position(self):
        """Returns the current position in the stream."""
        return self.stream.tell()

    def set_position(self, position):
        """Moves to the given absolute position in the stream."""
        self.stream.seek(position)

    def seek(self, offset):
        """
        Moves the position by ``offset`` bytes relative to where it is now.
        The offset may be negative.
        """
        self.stream.seek(offset, os.SEEK_CUR)

    # endregion

    # region with block

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    # endregion

View File

@@ -0,0 +1,147 @@
import inspect
import itertools
from . import utils
from .tl import types
# Which updates have the following fields?
# Which updates have the following fields?
# Maps (parameter name, annotation) to the constructor IDs that declare it.
_has_field = {
    key: [] for key in (
        ('user_id', int),
        ('chat_id', int),
        ('channel_id', int),
        ('peer', 'TypePeer'),
        ('peer', 'TypeDialogPeer'),
        ('message', 'TypeMessage'),
    )
}

# Note: We don't bother checking for some rare:
# * `UpdateChatParticipantAdd.inviter_id` integer.
# * `UpdateNotifySettings.peer` dialog peer.
# * `UpdatePinnedDialogs.order` list of dialog peers.
# * `UpdateReadMessagesContents.messages` list of messages.
# * `UpdateChatParticipants.participants` list of participants.
#
# There are also some uninteresting `update.message` of type string.


def _fill():
    """
    Populate `_has_field` by inspecting the ``__init__`` signature of every
    object in `types` whose ``SUBCLASS_OF_ID`` is ``0x9f89304e`` (presumably
    the ID shared by all updates -- confirm against the generated code).

    Raises ``RuntimeError`` if any of the fields ends up with no entries.
    """
    for name in dir(types):
        update = getattr(types, name)
        if getattr(update, 'SUBCLASS_OF_ID', None) != 0x9f89304e:
            continue

        cid = update.CONSTRUCTOR_ID
        params = inspect.signature(update.__init__).parameters
        for param in params.values():
            bucket = _has_field.get((param.name, param.annotation))
            if bucket is not None:
                bucket.append(cid)

    # Future-proof check: if the documentation format ever changes
    # then we won't be able to pick the update types we are interested
    # in, so we must make sure we have at least an update for each field
    # which likely means we are doing it right.
    if not all(_has_field.values()):
        raise RuntimeError('FIXME: Did the init signature or updates change?')


# We use a function to avoid cluttering the globals (with name/update/cid/doc)
_fill()
class EntityCache:
    """
    In-memory cache of input entities keyed by peer ID.

    The instance ``__dict__`` itself is used as the storage.
    """

    def add(self, entities):
        """
        Caches every entity found in ``entities`` that is not cached yet.

        Anything that isn't list-like is treated as a container and its
        ``chats``, ``users`` and ``user`` attributes (if any) are used.
        """
        if not utils.is_list_like(entities):
            # Invariant: all "chats" and "users" are always iterables,
            # and "user" never is (so we wrap it inside a list).
            entities = itertools.chain(
                getattr(entities, 'chats', []),
                getattr(entities, 'users', []),
                [entities.user] if hasattr(entities, 'user') else []
            )

        cache = self.__dict__
        for entity in entities:
            try:
                pid = utils.get_peer_id(entity)
                if pid not in cache:
                    # Note: `get_input_peer` already checks for `access_hash`
                    cache[pid] = utils.get_input_peer(entity)
            except TypeError:
                # Entities which can't be converted are simply skipped.
                pass

    def __getitem__(self, item):
        """
        Returns the cached :tl:`InputPeer` for the given ID or peer.

        Any failure (invalid key or nothing cached) raises ``KeyError``.
        """
        if isinstance(item, int) and item >= 0:
            # A bare non-negative integer could be any kind of peer, try all.
            for cls in (types.PeerUser, types.PeerChat, types.PeerChannel):
                result = self.__dict__.get(utils.get_peer_id(cls(item)))
                if result:
                    return result

            raise KeyError('No cached entity for the given key')

        try:
            return self.__dict__[utils.get_peer_id(item)]
        except TypeError:
            raise KeyError('Invalid key will not have entity') from None

    def clear(self):
        """
        Removes everything from the entity cache.
        """
        self.__dict__.clear()

    def ensure_cached(
            self,
            update,
            has_user_id=frozenset(_has_field[('user_id', int)]),
            has_chat_id=frozenset(_has_field[('chat_id', int)]),
            has_channel_id=frozenset(_has_field[('channel_id', int)]),
            has_peer=frozenset(_has_field[('peer', 'TypePeer')] + _has_field[('peer', 'TypeDialogPeer')]),
            has_message=frozenset(_has_field[('message', 'TypeMessage')])
    ):
        """
        Returns `True` if every relevant entity in ``update`` is cached,
        and `False` as soon as one of them is found to be missing.
        """
        # This method is called pretty often and we want it to have the lowest
        # overhead possible. For that, we avoid `isinstance` and constantly
        # getting attributes out of `types.` by "caching" the constructor IDs
        # in sets inside the arguments, and using local variables.
        known = self.__dict__
        cid = update.CONSTRUCTOR_ID

        if cid in has_user_id and update.user_id not in known:
            return False

        if cid in has_chat_id:
            if utils.get_peer_id(types.PeerChat(update.chat_id)) not in known:
                return False

        if cid in has_channel_id:
            if utils.get_peer_id(types.PeerChannel(update.channel_id)) not in known:
                return False

        if cid in has_peer and utils.get_peer_id(update.peer) not in known:
            return False

        if cid in has_message:
            message = update.message
            # `getattr` with a default handles MessageEmpty.
            for field in ('peer_id', 'from_id'):
                peer = getattr(message, field, None)
                if peer and utils.get_peer_id(peer) not in known:
                    return False

        # We don't quite worry about entities anywhere else.
        # This is enough.
        return True

363
telethon/_misc/helpers.py Normal file
View File

@@ -0,0 +1,363 @@
"""Various helpers not related to the Telegram API itself"""
import asyncio
import io
import enum
import os
import struct
import inspect
import logging
import functools
from pathlib import Path
from hashlib import sha1
class _EntityType(enum.Enum):
    """Coarse kind of entity, as returned by `_entity_type`."""
    USER = 0
    CHAT = 1
    CHANNEL = 2


# Module-level logger, used by `_FileStream` to warn about in-memory reads.
_log = logging.getLogger(__name__)
# region Multiple utilities
def generate_random_long(signed=True):
    """
    Returns a random 8-byte integer drawn from `os.urandom`.

    The value is interpreted as signed unless ``signed`` is false.
    """
    raw = os.urandom(8)
    return int.from_bytes(raw, byteorder='little', signed=signed)
def ensure_parent_dir_exists(file_path):
    """
    Creates the directory containing ``file_path`` (and any missing
    ancestors). Does nothing when the path has no directory component.
    """
    parent = os.path.dirname(file_path)
    if not parent:
        return

    os.makedirs(parent, exist_ok=True)
def add_surrogate(text):
    """
    Returns ``text`` with every character outside the BMP replaced by its
    UTF-16 surrogate pair (Telegram offsets are calculated with these).

    See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more.
    """
    pieces = []
    for char in text:
        if 0x10000 <= ord(char) <= 0x10FFFF:
            # SMP -> Surrogate Pairs
            high, low = struct.unpack('<HH', char.encode('utf-16le'))
            pieces.append(chr(high) + chr(low))
        else:
            pieces.append(char)

    return ''.join(pieces)
def del_surrogate(text):
    """
    Returns ``text`` with its surrogate pairs merged back into single
    characters, by round-tripping through UTF-16.
    """
    encoded = text.encode('utf-16', 'surrogatepass')
    return encoded.decode('utf-16')
def within_surrogate(text, index, *, length=None):
    """
    `True` if ``index`` is within a surrogate (before and after it, not at!).

    That is, `True` only when ``text[index - 1]`` is a high surrogate and
    ``text[index]`` is the low surrogate completing the pair, so splitting
    the text at ``index`` would break the pair in half.

    :param text: the text, with surrogate pairs already in place.
    :param index: the position to check.
    :param length: optional upper bound for ``index``; positions at or past
        it are never considered to be within a surrogate. It is clamped to
        ``len(text)`` and defaults to it.
    """
    if length is None:
        length = len(text)
    else:
        # Never trust the caller's bound beyond what can be indexed.
        length = min(length, len(text))

    return (
        0 < index < length and  # in bounds (index 1 can split a leading pair)
        '\ud800' <= text[index - 1] <= '\udbff' and  # previous is high
        '\udc00' <= text[index] <= '\udfff'  # current is low
    )
def strip_text(text, entities):
    """
    Strips whitespace from the given text modifying the provided entities.

    This assumes that there are no overlapping entities, that their length
    is greater or equal to one, and that their length is not out of bounds.

    :param text: the text to strip.
    :param entities: mutable list of objects with ``offset`` and ``length``
        attributes. It is modified in place: offsets and lengths are
        adjusted, and entities left with no text are removed.
    :return: the stripped text.
    """
    if not entities:
        return text.strip()

    while text and text[-1].isspace():
        e = entities[-1]
        if e.offset + e.length == len(text):
            # The last entity covers the character about to be removed.
            if e.length == 1:
                del entities[-1]
                if not entities:
                    return text.strip()
            else:
                e.length -= 1
        text = text[:-1]

    while text and text[0].isspace():
        # Walk backwards so that deleting by index does not shift the
        # entities which are yet to be visited.
        for i in reversed(range(len(entities))):
            e = entities[i]
            if e.offset != 0:
                # Entity starts after the removed character, shift it left.
                e.offset -= 1
                continue

            if e.length == 1:
                # Delete the entity being looked at (not blindly the first
                # one, which may be a different entity that must be kept).
                del entities[i]
                if not entities:
                    return text.lstrip()
            else:
                e.length -= 1

        text = text[1:]

    return text
def retry_range(retries, force_retry=True):
    """
    Yields the attempt numbers 1, 2, 3...

    When ``retries`` is `None` or negative the sequence never ends.
    Otherwise it yields ``retries`` numbers, plus one more when
    ``force_retry`` is set.
    """
    limit = retries

    # We need at least one iteration even if the retries are 0
    # when force_retry is True.
    if force_retry and limit is not None and not limit < 0:
        limit += 1

    current = 0
    while current != limit:
        current += 1
        yield current
async def _maybe_await(value):
if inspect.isawaitable(value):
return await value
else:
return value
async def _cancel(log, **tasks):
    """
    Helper to cancel one or more tasks gracefully, logging exceptions.

    :param log: logger whose ``exception`` method is used to report
        unexpected errors raised by a task after cancelling it.
    :param tasks: the tasks to cancel, passed as ``name=task``. The name
        is only used in the log message, and falsy tasks are skipped.
    """
    for name, task in tasks.items():
        if not task:
            continue

        task.cancel()
        try:
            # Wait for the task to actually finish after cancelling it.
            await task
        except asyncio.CancelledError:
            # Expected outcome, the task was cancelled.
            pass
        except RuntimeError:
            # Probably: RuntimeError: await wasn't used with future
            #
            # See: https://github.com/python/cpython/blob/12d3061c7819a73d891dcce44327410eaf0e1bc2/Lib/asyncio/futures.py#L265
            #
            # Happens with _asyncio.Task instances (in "Task cancelling" state)
            # trying to SIGINT the program right during initial connection, on
            # _recv_loop coroutine (but we're creating its task explicitly with
            # a loop, so how can it bug out like this?).
            #
            # Since we're aware of this error there's no point in logging it.
            # *May* be https://bugs.python.org/issue37172
            pass
        except AssertionError as e:
            # In Python 3.6, the above RuntimeError is an AssertionError
            # See https://github.com/python/cpython/blob/7df32f844efed33ca781a016017eab7050263b90/Lib/asyncio/futures.py#L328
            if e.args != ("yield from wasn't used with future",):
                log.exception('Unhandled exception from %s after cancelling '
                              '%s (%s)', name, type(task), task)
        except Exception:
            # Anything else is unexpected; log it and carry on with the
            # remaining tasks instead of propagating.
            log.exception('Unhandled exception from %s after cancelling '
                          '%s (%s)', name, type(task), task)
def _entity_type(entity):
# This could be a `utils` method that just ran a few `isinstance` on
# `utils.get_peer(...)`'s result. However, there are *a lot* of auto
# casts going on, plenty of calls and temporary short-lived objects.
#
# So we just check if a string is in the class name.
# Still, assert that it's the right type to not return false results.
try:
if entity.SUBCLASS_OF_ID not in (
0x2d45687, # crc32(b'Peer')
0xc91c90b6, # crc32(b'InputPeer')
0xe669bf46, # crc32(b'InputUser')
0x40f202fd, # crc32(b'InputChannel')
0x2da17977, # crc32(b'User')
0xc5af5d94, # crc32(b'Chat')
0x1f4661b9, # crc32(b'UserFull')
0xd49a2697, # crc32(b'ChatFull')
):
raise TypeError('{} does not have any entity type'.format(entity))
except AttributeError:
raise TypeError('{} is not a TLObject, cannot determine entity type'.format(entity))
name = entity.__class__.__name__
if 'User' in name:
return _EntityType.USER
elif 'Chat' in name:
return _EntityType.CHAT
elif 'Channel' in name:
return _EntityType.CHANNEL
elif 'Self' in name:
return _EntityType.USER
# 'Empty' in name or not found, we don't care, not a valid entity.
raise TypeError('{} does not have any entity type'.format(entity))
# endregion
# region Cryptographic related utils
def generate_key_data_from_nonce(server_nonce, new_nonce):
    """
    Derives the ``(key, iv)`` pair from the given nonces.

    Both nonces are integers, serialized as signed little-endian values
    of 16 bytes (``server_nonce``) and 32 bytes (``new_nonce``).
    """
    server_bytes = server_nonce.to_bytes(16, 'little', signed=True)
    new_bytes = new_nonce.to_bytes(32, 'little', signed=True)

    new_server = sha1(new_bytes + server_bytes).digest()
    server_new = sha1(server_bytes + new_bytes).digest()
    new_new = sha1(new_bytes + new_bytes).digest()

    key = new_server + server_new[:12]
    iv = server_new[12:20] + new_new + new_bytes[:4]
    return key, iv
# endregion
# region Custom Classes
class TotalList(list):
    """
    A list which also carries a `total` attribute. The total may differ
    from `len` because it counts the items *available* somewhere else,
    not the items *in this list*.

    Examples:

        .. code-block:: python

            # Telethon returns these lists in some cases (for example,
            # only when a chunk is returned, but the "total" count
            # is available).
            result = await client.get_messages(chat, limit=10)

            print(result.total)  # large number
            print(len(result))  # 10
            print(result[0])  # latest message

            for x in result:  # show the 10 messages
                print(x.text)

    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Starts at zero; whoever builds the list sets the real value.
        self.total = 0

    def __str__(self):
        items = ', '.join(map(str, self))
        return '[{}, total={}]'.format(items, self.total)

    def __repr__(self):
        items = ', '.join(map(repr, self))
        return '[{}, total={}]'.format(items, self.total)
class _FileStream(io.IOBase):
    """
    Proxy around things that represent a file and need to be used as streams
    which may or not need to be closed.

    This will handle `pathlib.Path`, `str` paths, in-memory `bytes`, and
    anything IO-like (including `aiofiles`).

    It also provides access to the name and file size (also necessary).

    The underlying stream is only set up when entering the instance with
    ``async with``, so the proxied methods should not be used before that.
    """
    def __init__(self, file, *, file_size=None):
        # `Path` objects are normalized into an absolute `str` path.
        if isinstance(file, Path):
            file = str(file.absolute())

        self._file = file
        self._name = None
        self._size = file_size
        self._stream = None
        # Whether `__aexit__` is responsible for closing `_stream`.
        self._close_stream = None

    async def __aenter__(self):
        if isinstance(self._file, str):
            # Local path: we open the file, so we must close it as well.
            self._name = os.path.basename(self._file)
            self._size = os.path.getsize(self._file)
            self._stream = open(self._file, 'rb')
            self._close_stream = True

        elif isinstance(self._file, bytes):
            # In-memory data wrapped in a stream we own.
            self._size = len(self._file)
            self._stream = io.BytesIO(self._file)
            self._close_stream = True

        elif not callable(getattr(self._file, 'read', None)):
            raise TypeError('file description should have a `read` method')

        elif self._size is not None:
            # The caller told us the size; use their stream as-is and
            # leave closing it up to them.
            self._name = getattr(self._file, 'name', None)
            self._stream = self._file
            self._close_stream = False

        else:
            # Unknown size. The stream's methods may or may not be
            # coroutines, hence `_maybe_await` on every call.
            if callable(getattr(self._file, 'seekable', None)):
                seekable = await _maybe_await(self._file.seekable())
            else:
                seekable = False

            if seekable:
                # Seek to the end to learn the size, then restore the
                # original position.
                pos = await _maybe_await(self._file.tell())
                await _maybe_await(self._file.seek(0, os.SEEK_END))
                self._size = await _maybe_await(self._file.tell())
                await _maybe_await(self._file.seek(pos, os.SEEK_SET))
                self._stream = self._file
                self._close_stream = False
            else:
                _log.warning(
                    'Could not determine file size beforehand so the entire '
                    'file will be read in-memory')

                data = await _maybe_await(self._file.read())
                self._size = len(data)
                self._stream = io.BytesIO(data)
                self._close_stream = True

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Only close streams that were created in `__aenter__`.
        if self._close_stream and self._stream:
            await _maybe_await(self._stream.close())

    @property
    def file_size(self):
        # Size in bytes; `None` until known (given or computed on enter).
        return self._size

    @property
    def name(self):
        # File name if one could be determined, `None` otherwise.
        return self._name

    # Proxy all the methods. Doesn't need to be readable (makes multiline edits easier)
    def read(self, *args, **kwargs): return self._stream.read(*args, **kwargs)
    def readinto(self, *args, **kwargs): return self._stream.readinto(*args, **kwargs)
    def write(self, *args, **kwargs): return self._stream.write(*args, **kwargs)
    def fileno(self, *args, **kwargs): return self._stream.fileno(*args, **kwargs)
    def flush(self, *args, **kwargs): return self._stream.flush(*args, **kwargs)
    def isatty(self, *args, **kwargs): return self._stream.isatty(*args, **kwargs)
    def readable(self, *args, **kwargs): return self._stream.readable(*args, **kwargs)
    def readline(self, *args, **kwargs): return self._stream.readline(*args, **kwargs)
    def readlines(self, *args, **kwargs): return self._stream.readlines(*args, **kwargs)
    def seek(self, *args, **kwargs): return self._stream.seek(*args, **kwargs)
    def seekable(self, *args, **kwargs): return self._stream.seekable(*args, **kwargs)
    def tell(self, *args, **kwargs): return self._stream.tell(*args, **kwargs)
    def truncate(self, *args, **kwargs): return self._stream.truncate(*args, **kwargs)
    def writable(self, *args, **kwargs): return self._stream.writable(*args, **kwargs)
    def writelines(self, *args, **kwargs): return self._stream.writelines(*args, **kwargs)

    # close is special because it will be called by __del__ but we do NOT
    # want to close the file unless we have to (we're just a wrapper).
    # Instead, we do nothing (we should be used through the decorator which
    # has its own mechanism to close the file correctly).
    def close(self, *args, **kwargs):
        pass
# endregion

67
telethon/_misc/hints.py Normal file
View File

@@ -0,0 +1,67 @@
import datetime
import typing
from . import helpers
from .tl import types, custom
# Plain aliases; they only document intent, the runtime type is unchanged.
Phone = str
Username = str
PeerID = int
Entity = typing.Union[types.User, types.Chat, types.Channel]
FullEntity = typing.Union[types.UserFull, types.messages.ChatFull, types.ChatFull, types.ChannelFull]

# Anything accepted where a single entity is expected.
EntityLike = typing.Union[
    Phone,
    Username,
    PeerID,
    types.TypePeer,
    types.TypeInputPeer,
    Entity,
    FullEntity
]
# Either a single entity-like value or a sequence of them.
EntitiesLike = typing.Union[EntityLike, typing.Sequence[EntityLike]]

ButtonLike = typing.Union[types.TypeKeyboardButton, custom.Button]
# A ready-made markup, a single button, a row of buttons or rows of buttons.
MarkupLike = typing.Union[
    types.TypeReplyMarkup,
    ButtonLike,
    typing.Sequence[ButtonLike],
    typing.Sequence[typing.Sequence[ButtonLike]]
]

# Re-exported so it can be referenced from this module.
TotalList = helpers.TotalList

# Note that this is `Optional`, so `None` is also accepted.
DateLike = typing.Optional[typing.Union[float, datetime.datetime, datetime.date, datetime.timedelta]]

# The three aliases below are all plain `str`; the names only tell apart
# what the string is expected to contain.
LocalPath = str
ExternalUrl = str
BotFileID = str
FileLike = typing.Union[
    LocalPath,
    ExternalUrl,
    BotFileID,
    bytes,
    typing.BinaryIO,
    types.TypeMessageMedia,
    types.TypeInputFile,
    types.TypeInputFileLocation
]

# Can't use `typing.Type` in Python 3.5.2
# See https://github.com/python/typing/issues/266
try:
    OutFileLike = typing.Union[
        str,
        typing.Type[bytes],
        typing.BinaryIO
    ]
except TypeError:
    # Fallback without `typing.Type[bytes]` for the version noted above.
    OutFileLike = typing.Union[
        str,
        typing.BinaryIO
    ]

MessageLike = typing.Union[str, types.Message]
MessageIDLike = typing.Union[int, types.Message, types.TypeInputMessage]

# Callable taking two integers and returning nothing.
ProgressCallback = typing.Callable[[int, int], None]

229
telethon/_misc/html.py Normal file
View File

@@ -0,0 +1,229 @@
"""
Simple HTML -> Telegram entity parser.
"""
import struct
from collections import deque
from html import escape
from html.parser import HTMLParser
from typing import Iterable, Optional, Tuple, List
from .. import helpers
from ..tl.types import (
MessageEntityBold, MessageEntityItalic, MessageEntityCode,
MessageEntityPre, MessageEntityEmail, MessageEntityUrl,
MessageEntityTextUrl, MessageEntityMentionName,
MessageEntityUnderline, MessageEntityStrike, MessageEntityBlockquote,
TypeMessageEntity
)
# Surrogate helpers (same logic as `helpers.add_surrogate` / `helpers.del_surrogate`)
def _add_surrogate(text):
return ''.join(
''.join(chr(y) for y in struct.unpack('<HH', x.encode('utf-16le')))
if (0x10000 <= ord(x) <= 0x10FFFF) else x for x in text
)
def _del_surrogate(text):
return text.encode('utf-16', 'surrogatepass').decode('utf-16')
class HTMLToTelegramParser(HTMLParser):
    """
    `HTMLParser` which accumulates the plain text in ``text`` and the
    entities for the supported tags in ``entities`` as it is fed HTML.
    """
    def __init__(self):
        super().__init__()
        self.text = ''
        self.entities = []
        # Entities whose tag is still open, keyed by tag name.
        self._building_entities = {}
        # Stacks of open tags (most recent first) and, in parallel, the
        # text that should replace the data inside of them, or `None`.
        self._open_tags = deque()
        self._open_tags_meta = deque()

    def handle_starttag(self, tag, attrs):
        self._open_tags.appendleft(tag)
        self._open_tags_meta.appendleft(None)

        attrs = dict(attrs)
        EntityType = None
        args = {}
        if tag in ('strong', 'b'):
            EntityType = MessageEntityBold
        elif tag in ('em', 'i'):
            EntityType = MessageEntityItalic
        elif tag == 'u':
            EntityType = MessageEntityUnderline
        elif tag in ('del', 's'):
            EntityType = MessageEntityStrike
        elif tag == 'blockquote':
            EntityType = MessageEntityBlockquote
        elif tag == 'code':
            if 'pre' in self._building_entities:
                # If we're in the middle of a <pre> tag, this <code> tag is
                # probably intended for syntax highlighting.
                #
                # Syntax highlighting is set with
                #     <code class='language-...'>codeblock</code>
                # inside <pre> tags
                pre = self._building_entities['pre']
                if 'class' in attrs:
                    pre.language = attrs['class'][len('language-'):]
            else:
                EntityType = MessageEntityCode
        elif tag == 'pre':
            EntityType = MessageEntityPre
            args['language'] = ''
        elif tag == 'a':
            if 'href' not in attrs:
                # Anchors without a target produce no entity.
                return

            url = attrs['href']
            if url.startswith('mailto:'):
                url = url[len('mailto:'):]
                EntityType = MessageEntityEmail
            elif self.get_starttag_text() == url:
                EntityType = MessageEntityUrl
            else:
                EntityType = MessageEntityTextUrl
                args['url'] = url
                url = None

            # Replace the placeholder pushed on entry with the actual value.
            self._open_tags_meta[0] = url

        if EntityType and tag not in self._building_entities:
            self._building_entities[tag] = EntityType(
                offset=len(self.text),
                # The length will be determined when closing the tag.
                length=0,
                **args)

    def handle_data(self, text):
        if self._open_tags and self._open_tags[0] == 'a':
            url = self._open_tags_meta[0]
            if url:
                text = url

        # Every entity which is still open grows by the added text.
        for entity in self._building_entities.values():
            entity.length += len(text)

        self.text += text

    def handle_endtag(self, tag):
        try:
            self._open_tags.popleft()
            self._open_tags_meta.popleft()
        except IndexError:
            # Closing tag without a matching opening one.
            pass

        entity = self._building_entities.pop(tag, None)
        if entity:
            self.entities.append(entity)
def parse(html: str) -> Tuple[str, List[TypeMessageEntity]]:
    """
    Parses the given HTML message and returns its stripped representation
    plus a list of the MessageEntity's that were found.

    Falsy input is returned untouched along with an empty list.

    :param html: the message with HTML to be parsed.
    :return: a tuple consisting of (clean message, [message entities]).
    """
    if not html:
        return html, []

    parser = HTMLToTelegramParser()
    # Entity offsets are computed over the text with surrogate pairs.
    parser.feed(_add_surrogate(html))
    entities = parser.entities
    stripped = helpers.strip_text(parser.text, entities)
    return _del_surrogate(stripped), entities
def unparse(text: str, entities: Iterable[TypeMessageEntity], _offset: int = 0,
            _length: Optional[int] = None) -> str:
    """
    Performs the reverse operation to .parse(), effectively returning HTML
    given a normal text and its MessageEntity's.

    The ``_offset`` and ``_length`` parameters are only meant for the
    recursive calls this function makes to render nested entities.

    :param text: the text to be reconverted into HTML.
    :param entities: the MessageEntity's applied to the text. Note that
        they are sliced (``entities[i + 1:]``), so a sequence is needed.
    :param _offset: offset of ``text`` within the original message.
    :param _length: length of the region being processed, which defaults
        to the length of the text (counting surrogate pairs as two).
    :return: a HTML representation of the combination of both inputs.
    """
    if not text:
        return text
    elif not entities:
        return escape(text)

    text = _add_surrogate(text)
    if _length is None:
        _length = len(text)
    html = []
    # Position (relative to `text`) up to which output was already emitted.
    last_offset = 0
    for i, entity in enumerate(entities):
        # Entities starting past this region are not ours to handle.
        if entity.offset >= _offset + _length:
            break
        relative_offset = entity.offset - _offset
        if relative_offset > last_offset:
            # Plain text between the previous entity and this one.
            html.append(escape(text[last_offset:relative_offset]))
        elif relative_offset < last_offset:
            # Starts inside text which was already emitted (it was handled
            # by a recursive call as a nested entity), so skip it.
            continue

        skip_entity = False
        length = entity.length

        # If we are in the middle of a surrogate nudge the position by +1.
        # Otherwise we would end up with malformed text and fail to encode.
        # For example of bad input: "Hi \ud83d\ude1c"
        # https://en.wikipedia.org/wiki/UTF-16#U+010000_to_U+10FFFF
        while helpers.within_surrogate(text, relative_offset, length=_length):
            relative_offset += 1

        while helpers.within_surrogate(text, relative_offset + length, length=_length):
            length += 1

        # Render the inner text recursively, so that the remaining
        # entities which fall inside of this one are nested in it.
        entity_text = unparse(text=text[relative_offset:relative_offset + length],
                              entities=entities[i + 1:],
                              _offset=entity.offset, _length=length)
        entity_type = type(entity)

        if entity_type == MessageEntityBold:
            html.append('<strong>{}</strong>'.format(entity_text))
        elif entity_type == MessageEntityItalic:
            html.append('<em>{}</em>'.format(entity_text))
        elif entity_type == MessageEntityCode:
            html.append('<code>{}</code>'.format(entity_text))
        elif entity_type == MessageEntityUnderline:
            html.append('<u>{}</u>'.format(entity_text))
        elif entity_type == MessageEntityStrike:
            html.append('<del>{}</del>'.format(entity_text))
        elif entity_type == MessageEntityBlockquote:
            html.append('<blockquote>{}</blockquote>'.format(entity_text))
        elif entity_type == MessageEntityPre:
            if entity.language:
                html.append(
                    "<pre>\n"
                    "    <code class='language-{}'>\n"
                    "        {}\n"
                    "    </code>\n"
                    "</pre>".format(entity.language, entity_text))
            else:
                html.append('<pre><code>{}</code></pre>'
                            .format(entity_text))
        elif entity_type == MessageEntityEmail:
            html.append('<a href="mailto:{0}">{0}</a>'.format(entity_text))
        elif entity_type == MessageEntityUrl:
            html.append('<a href="{0}">{0}</a>'.format(entity_text))
        elif entity_type == MessageEntityTextUrl:
            html.append('<a href="{}">{}</a>'
                        .format(escape(entity.url), entity_text))
        elif entity_type == MessageEntityMentionName:
            html.append('<a href="tg://user?id={}">{}</a>'
                        .format(entity.user_id, entity_text))
        else:
            # Unsupported entity, its text will be emitted as plain text.
            skip_entity = True
        last_offset = relative_offset + (0 if skip_entity else length)

    # Same surrogate nudging as above, for the start of the trailing text.
    while helpers.within_surrogate(text, last_offset, length=_length):
        last_offset += 1

    html.append(escape(text[last_offset:]))
    return _del_surrogate(''.join(html))

197
telethon/_misc/markdown.py Normal file
View File

@@ -0,0 +1,197 @@
"""
Simple markdown parser which does not support nesting. Intended primarily
for use within the library, which attempts to handle emojis correctly,
since they seem to count as two characters and it's a bit strange.
"""
import re
import warnings
from ..helpers import add_surrogate, del_surrogate, within_surrogate, strip_text
from ..tl import TLObject
from ..tl.types import (
MessageEntityBold, MessageEntityItalic, MessageEntityCode,
MessageEntityPre, MessageEntityTextUrl, MessageEntityMentionName,
MessageEntityStrike
)
# Maps each markdown delimiter to the entity type it produces.
DEFAULT_DELIMITERS = {
    '**': MessageEntityBold,
    '__': MessageEntityItalic,
    '~~': MessageEntityStrike,
    '`': MessageEntityCode,
    '```': MessageEntityPre
}

# Matches inline URLs such as `[text](url)`; the first group is the
# visible text and the second group is the URL.
DEFAULT_URL_RE = re.compile(r'\[([\S\s]+?)\]\((.+?)\)')
# Format string for the inverse operation: {0} is the text, {1} the URL.
DEFAULT_URL_FORMAT = '[{0}]({1})'
def overlap(a, b, x, y):
    """
    `True` if the ranges ``a..b`` and ``x..y`` share more than a boundary,
    that is, if the largest start is below the smallest end.
    """
    latest_start = max(a, x)
    earliest_end = min(b, y)
    return latest_start < earliest_end
def parse(message, delimiters=None, url_re=None):
    """
    Parses the given markdown message and returns its stripped representation
    plus a list of the MessageEntity's that were found.

    Falsy messages are returned untouched along with an empty list, and so
    is any message when ``delimiters`` is empty but not `None`.

    :param message: the message with markdown-like syntax to be parsed.
    :param delimiters: the delimiters to be used, {delimiter: type}.
        `None` means `DEFAULT_DELIMITERS`.
    :param url_re: the regex used to match inline URLs, either compiled or
        as a `str` pattern. Must have two groups: the visible text and the
        URL. `None` means `DEFAULT_URL_RE`.
    :return: a tuple consisting of (clean message, [message entities]).
    """
    if not message:
        return message, []

    if url_re is None:
        url_re = DEFAULT_URL_RE
    elif isinstance(url_re, str):
        url_re = re.compile(url_re)

    if not delimiters:
        if delimiters is not None:
            return message, []
        delimiters = DEFAULT_DELIMITERS

    # Build a regex to efficiently test all delimiters at once.
    # Note that the largest delimiter should go first, we don't
    # want ``` to be interpreted as a single back-tick in a code block.
    delim_re = re.compile('|'.join('({})'.format(re.escape(k))
                                   for k in sorted(delimiters, key=len, reverse=True)))

    # Cannot use a for loop because we need to skip some indices
    i = 0
    result = []

    # Replace characters outside the BMP with their surrogate pairs so
    # that the indices into the string match the offsets entities need.
    message = add_surrogate(message)
    while i < len(message):
        m = delim_re.match(message, pos=i)

        # Did we find some delimiter here at `i`?
        if m:
            delim = next(filter(None, m.groups()))

            # +1 to avoid matching right after (e.g. "****")
            end = message.find(delim, i + len(delim) + 1)

            # Did we find the earliest closing tag?
            if end != -1:

                # Remove the delimiter from the string
                message = ''.join((
                    message[:i],
                    message[i + len(delim):end],
                    message[end + len(delim):]
                ))

                # Check other affected entities
                for ent in result:
                    # If the end is after our start, it is affected
                    if ent.offset + ent.length > i:
                        # If the old start is also before ours, it is fully enclosed
                        if ent.offset <= i:
                            ent.length -= len(delim) * 2
                        else:
                            ent.length -= len(delim)

                # Append the found entity
                ent = delimiters[delim]
                if ent == MessageEntityPre:
                    result.append(ent(i, end - i - len(delim), ''))  # has 'lang'
                else:
                    result.append(ent(i, end - i - len(delim)))

                # No nested entities inside code blocks
                if ent in (MessageEntityCode, MessageEntityPre):
                    i = end - len(delim)

                # Stay at `i` (unless moved above) to find nested entities.
                continue

        elif url_re:
            m = url_re.match(message, pos=i)
            if m:
                # Replace the whole match with only the inline URL text.
                message = ''.join((
                    message[:m.start()],
                    m.group(1),
                    message[m.end():]
                ))

                # Amount of characters removed: everything in the match
                # except for the visible text. (Subtracting the length of
                # the whole match instead would always yield zero.)
                delim_size = m.end() - m.start() - len(m.group(1))
                for ent in result:
                    # If the end is after our start, it is affected
                    if ent.offset + ent.length > m.start():
                        ent.length -= delim_size

                result.append(MessageEntityTextUrl(
                    offset=m.start(), length=len(m.group(1)),
                    url=del_surrogate(m.group(2))
                ))
                i += len(m.group(1))
                continue

        i += 1

    message = strip_text(message, result)
    return del_surrogate(message), result
def unparse(text, entities, delimiters=None, url_fmt=None):
    """
    Inverse of .parse(): rebuilds markdown-like syntax from a plain text
    and the MessageEntity's that apply to it.

    :param text: the text to be reconverted into markdown.
    :param entities: the MessageEntity's applied to the text (a single
        entity is also accepted).
    :param delimiters: mapping of delimiter string to entity type. `None`
        selects the default delimiters; any other falsy value leaves the
        text untouched.
    :param url_fmt: deprecated; passing anything other than `None` only
        emits a warning.
    :return: a markdown-like text representing the combination of both inputs.
    """
    if not text or not entities:
        return text

    if delimiters is None:
        delimiters = DEFAULT_DELIMITERS
    elif not delimiters:
        # Explicitly empty delimiters: there is nothing we could insert.
        return text

    if url_fmt is not None:
        warnings.warn('url_fmt is deprecated')  # since it complicates everything *a lot*

    if isinstance(entities, TLObject):
        entities = (entities,)

    text = add_surrogate(text)
    # We need to look up the delimiter by entity type, so invert the mapping.
    delimiters = {v: k for k, v in delimiters.items()}

    insertions = []
    for entity in entities:
        start = entity.offset
        stop = entity.offset + entity.length
        delimiter = delimiters.get(type(entity), None)
        if delimiter:
            insertions.append((start, delimiter))
            insertions.append((stop, delimiter))
            continue

        if isinstance(entity, MessageEntityTextUrl):
            url = entity.url
        elif isinstance(entity, MessageEntityMentionName):
            url = 'tg://user?id={}'.format(entity.user_id)
        else:
            url = None

        if url:
            insertions.append((start, '['))
            insertions.append((stop, ']({})'.format(url)))

    # Stable sort by position, then walk from the end of the text towards the
    # start so that earlier positions are not shifted by later insertions.
    insertions.sort(key=lambda pair: pair[0])
    for at, what in reversed(insertions):
        # If we are in the middle of a surrogate pair, nudge the position
        # forward until we are past it. Otherwise we would end up with
        # malformed text and fail to encode.
        # For example of bad input: "Hi \ud83d\ude1c"
        # https://en.wikipedia.org/wiki/UTF-16#U+010000_to_U+10FFFF
        while within_surrogate(text, at):
            at += 1

        text = text[:at] + what + text[at:]

    return del_surrogate(text)

View File

@@ -0,0 +1,111 @@
import asyncio
import collections
import io
import struct
from ..tl import TLRequest
from ..tl.core.messagecontainer import MessageContainer
from ..tl.core.tlmessage import TLMessage
class MessagePacker:
    """
    This class packs `RequestState` as outgoing `TLMessages`.

    The purpose of this class is to support putting N `RequestState` into a
    queue, and then awaiting for "packed" `TLMessage` in the other end. The
    simplest case would be ``State -> TLMessage`` (1-to-1 relationship) but
    for efficiency purposes it's ``States -> Container`` (N-to-1).

    This addresses several needs: outgoing messages will be smaller, so the
    encryption and network overhead also is smaller. It's also a central
    point where outgoing requests are put, and where ready-messages are get.
    """

    def __init__(self, state, loggers):
        """
        :param state: object providing ``write_data_as_message``, used to
            serialize each queued item and obtain its message ID.
        :param loggers: mapping from module name to logger.
        """
        self._state = state
        self._deque = collections.deque()
        self._ready = asyncio.Event()
        self._log = loggers[__name__]

    def append(self, state):
        """Queue a single request state and wake up any pending `get`."""
        self._deque.append(state)
        self._ready.set()

    def extend(self, states):
        """Queue several request states and wake up any pending `get`."""
        self._deque.extend(states)
        self._ready.set()

    async def get(self):
        """
        Returns (batch, data) if one or more items could be retrieved.

        If the cancellation occurs or only invalid items were in the
        queue, (None, None) will be returned instead.
        """
        if not self._deque:
            self._ready.clear()
            await self._ready.wait()

        buffer = io.BytesIO()
        batch = []
        size = 0

        # Fill a new batch to return while the size is small enough,
        # as long as we don't exceed the maximum length of messages.
        #
        # The comparison must be strict: the loop body adds one more item,
        # so `<=` would allow MAXIMUM_LENGTH + 1 messages in one container.
        while self._deque and len(batch) < MessageContainer.MAXIMUM_LENGTH:
            state = self._deque.popleft()
            size += len(state.data) + TLMessage.SIZE_OVERHEAD

            if size <= MessageContainer.MAXIMUM_SIZE:
                state.msg_id = self._state.write_data_as_message(
                    buffer, state.data, isinstance(state.request, TLRequest),
                    after_id=state.after.msg_id if state.after else None
                )
                batch.append(state)
                self._log.debug('Assigned msg_id = %d to %s (%x)',
                                state.msg_id, state.request.__class__.__name__,
                                id(state.request))
                continue

            if batch:
                # Put the item back since it can't be sent in this batch
                self._deque.appendleft(state)
                break

            # If a single message exceeds the maximum size, then the
            # message payload cannot be sent. Telegram would forcibly
            # close the connection; message would never be confirmed.
            #
            # We don't put the item back because it can never be sent.
            # If we did, we would loop again and reach this same path.
            # Setting the exception twice results in `InvalidStateError`
            # and this method should never return with error, which we
            # really want to avoid.
            self._log.warning(
                'Message payload for %s is too long (%d) and cannot be sent',
                state.request.__class__.__name__, len(state.data)
            )
            state.future.set_exception(
                ValueError('Request payload is too big'))

            # The rejected item no longer counts towards this batch.
            size = 0
            continue

        if not batch:
            return None, None

        if len(batch) > 1:
            # Inlined code to pack several messages into a container
            data = struct.pack(
                '<Ii', MessageContainer.CONSTRUCTOR_ID, len(batch)
            ) + buffer.getvalue()
            buffer = io.BytesIO()
            container_id = self._state.write_data_as_message(
                buffer, data, content_related=False
            )
            for s in batch:
                s.container_id = container_id

        data = buffer.getvalue()
        return batch, data

194
telethon/_misc/password.py Normal file
View File

@@ -0,0 +1,194 @@
import hashlib
import os
from .crypto import factorization
from .tl import types
def check_prime_and_good_check(prime: int, g: int):
    """
    Validate that `prime` and the generator `g` form an acceptable pair.

    The checks, in order: `prime` is a non-negative 2048-bit number, it
    cannot be factorized, `g` is one of 2 to 7 and `prime` has an accepted
    residue for that `g`, and ``(prime - 1) // 2`` cannot be factorized
    either.

    :raises ValueError: as soon as any of the checks fails.
    """
    expected_bits = 2048
    if prime < 0 or prime.bit_length() != expected_bits:
        raise ValueError('bad prime count {}, expected {}'
                         .format(prime.bit_length(), expected_bits))

    # TODO This is awfully slow
    if factorization.Factorization.factorize(prime)[0] != 1:
        raise ValueError('given "prime" is not prime')

    # generator -> (modulus, accepted residues, error message template).
    # A generator of 4 has no residue requirement, so it is not listed here.
    residue_rules = {
        2: (8, (7,), 'bad g {}, mod8 {}'),
        3: (3, (2,), 'bad g {}, mod3 {}'),
        5: (5, (1, 4), 'bad g {}, mod5 {}'),
        6: (24, (19, 23), 'bad g {}, mod24 {}'),
        7: (7, (3, 5, 6), 'bad g {}, mod7 {}'),
    }
    if g in residue_rules:
        modulus, accepted, template = residue_rules[g]
        if prime % modulus not in accepted:
            raise ValueError(template.format(g, prime % modulus))
    elif g != 4:
        raise ValueError('bad g {}'.format(g))

    half = (prime - 1) // 2
    if factorization.Factorization.factorize(half)[0] != 1:
        raise ValueError('(prime - 1) // 2 is not prime')

    # Else it's good
def check_prime_and_good(prime_bytes: bytes, g: int):
    """
    Validate the big-endian encoded prime `prime_bytes` with generator `g`.

    A hardcoded prime combined with a generator of 3, 4, 5 or 7 is accepted
    immediately. Everything else goes through the slow path,
    `check_prime_and_good_check`, which raises `ValueError` on failure.
    """
    known_prime = bytes((
        0xC7, 0x1C, 0xAE, 0xB9, 0xC6, 0xB1, 0xC9, 0x04, 0x8E, 0x6C, 0x52, 0x2F, 0x70, 0xF1, 0x3F, 0x73,
        0x98, 0x0D, 0x40, 0x23, 0x8E, 0x3E, 0x21, 0xC1, 0x49, 0x34, 0xD0, 0x37, 0x56, 0x3D, 0x93, 0x0F,
        0x48, 0x19, 0x8A, 0x0A, 0xA7, 0xC1, 0x40, 0x58, 0x22, 0x94, 0x93, 0xD2, 0x25, 0x30, 0xF4, 0xDB,
        0xFA, 0x33, 0x6F, 0x6E, 0x0A, 0xC9, 0x25, 0x13, 0x95, 0x43, 0xAE, 0xD4, 0x4C, 0xCE, 0x7C, 0x37,
        0x20, 0xFD, 0x51, 0xF6, 0x94, 0x58, 0x70, 0x5A, 0xC6, 0x8C, 0xD4, 0xFE, 0x6B, 0x6B, 0x13, 0xAB,
        0xDC, 0x97, 0x46, 0x51, 0x29, 0x69, 0x32, 0x84, 0x54, 0xF1, 0x8F, 0xAF, 0x8C, 0x59, 0x5F, 0x64,
        0x24, 0x77, 0xFE, 0x96, 0xBB, 0x2A, 0x94, 0x1D, 0x5B, 0xCD, 0x1D, 0x4A, 0xC8, 0xCC, 0x49, 0x88,
        0x07, 0x08, 0xFA, 0x9B, 0x37, 0x8E, 0x3C, 0x4F, 0x3A, 0x90, 0x60, 0xBE, 0xE6, 0x7C, 0xF9, 0xA4,
        0xA4, 0xA6, 0x95, 0x81, 0x10, 0x51, 0x90, 0x7E, 0x16, 0x27, 0x53, 0xB5, 0x6B, 0x0F, 0x6B, 0x41,
        0x0D, 0xBA, 0x74, 0xD8, 0xA8, 0x4B, 0x2A, 0x14, 0xB3, 0x14, 0x4E, 0x0E, 0xF1, 0x28, 0x47, 0x54,
        0xFD, 0x17, 0xED, 0x95, 0x0D, 0x59, 0x65, 0xB4, 0xB9, 0xDD, 0x46, 0x58, 0x2D, 0xB1, 0x17, 0x8D,
        0x16, 0x9C, 0x6B, 0xC4, 0x65, 0xB0, 0xD6, 0xFF, 0x9C, 0xA3, 0x92, 0x8F, 0xEF, 0x5B, 0x9A, 0xE4,
        0xE4, 0x18, 0xFC, 0x15, 0xE8, 0x3E, 0xBE, 0xA0, 0xF8, 0x7F, 0xA9, 0xFF, 0x5E, 0xED, 0x70, 0x05,
        0x0D, 0xED, 0x28, 0x49, 0xF4, 0x7B, 0xF9, 0x59, 0xD9, 0x56, 0x85, 0x0C, 0xE9, 0x29, 0x85, 0x1F,
        0x0D, 0x81, 0x15, 0xF6, 0x35, 0xB1, 0x05, 0xEE, 0x2E, 0x4E, 0x15, 0xD0, 0x4B, 0x24, 0x54, 0xBF,
        0x6F, 0x4F, 0xAD, 0xF0, 0x34, 0xB1, 0x04, 0x03, 0x11, 0x9C, 0xD8, 0xE3, 0xB9, 0x2F, 0xCC, 0x5B))

    # Fast path: the hardcoded prime with one of these generators is
    # accepted without running the expensive checks.
    if known_prime == prime_bytes and g in (3, 4, 5, 7):
        return  # It's good

    check_prime_and_good_check(int.from_bytes(prime_bytes, 'big'), g)
def is_good_large(number: int, p: int) -> bool:
    """Return whether `number` lies strictly between 0 and `p`."""
    if number <= 0:
        return False
    return p - number > 0
# Fixed width, in bytes, that numbers are padded to before being hashed.
SIZE_FOR_HASH = 256


def num_bytes_for_hash(number: bytes) -> bytes:
    """
    Left-pad `number` with zero bytes up to `SIZE_FOR_HASH` bytes.

    Raises `ValueError` if `number` is already longer than that, because
    the padding length would be negative.
    """
    padding = bytes(SIZE_FOR_HASH - len(number))
    return padding + number
def big_num_for_hash(g: int) -> bytes:
    """Encode `g` as exactly `SIZE_FOR_HASH` big-endian bytes."""
    return g.to_bytes(length=SIZE_FOR_HASH, byteorder='big')
def sha256(*p: bytes) -> bytes:
    """Return the SHA-256 digest of all the given chunks, fed in order."""
    hasher = hashlib.sha256()
    for chunk in p:
        hasher.update(chunk)

    return hasher.digest()
def is_good_mod_exp_first(modexp, prime) -> bool:
    """
    Check that `modexp` is acceptably sized relative to `prime`.

    It is accepted when it does not exceed `prime`, both `modexp` and
    ``prime - modexp`` have at least ``2048 - 64`` bits, and `modexp`
    fits in 256 bytes.
    """
    min_bits = 2048 - 64
    max_bytes = 256

    gap = prime - modexp
    if gap < 0:
        return False
    if gap.bit_length() < min_bits:
        return False

    modexp_bits = modexp.bit_length()
    if modexp_bits < min_bits:
        return False

    # Round the bit count up to whole bytes before comparing.
    return (modexp_bits + 7) // 8 <= max_bytes
def xor(a: bytes, b: bytes) -> bytes:
    """
    Byte-wise XOR of `a` and `b`.

    The result is as long as the shorter of the two inputs.
    """
    return bytes([left ^ right for left, right in zip(a, b)])
def pbkdf2sha512(password: bytes, salt: bytes, iterations: int):
    """Derive a key from `password` and `salt` with PBKDF2-HMAC-SHA512."""
    return hashlib.pbkdf2_hmac(
        hash_name='sha512',
        password=password,
        salt=salt,
        iterations=iterations,
    )
def compute_hash(algo: types.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow,
                 password: str):
    """
    Derive the password hash for the given KDF parameters.

    The UTF-8 encoded password is hashed with SHA-256 salted by `salt1`,
    then by `salt2`, stretched with 100000 rounds of PBKDF2-HMAC-SHA512
    using `salt1`, and finally hashed again with SHA-256 salted by `salt2`.
    Each salted hash wraps the payload in the salt on both sides.
    """
    def salted(salt, payload):
        return sha256(salt, payload, salt)

    first = salted(algo.salt1, password.encode('utf-8'))
    second = salted(algo.salt2, first)
    stretched = pbkdf2sha512(second, algo.salt1, 100000)
    return salted(algo.salt2, stretched)
def compute_digest(algo: types.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow,
                   password: str):
    """
    Compute ``g ** hash(password) mod p`` in its fixed-width byte form.

    :raises ValueError: if the ``p``/``g`` pair in `algo` fails validation.
    """
    try:
        check_prime_and_good(algo.p, algo.g)
    except ValueError:
        raise ValueError('bad p/g in password')

    exponent = int.from_bytes(compute_hash(algo, password), 'big')
    modulus = int.from_bytes(algo.p, 'big')
    return big_num_for_hash(pow(algo.g, exponent, modulus))
# https://github.com/telegramdesktop/tdesktop/blob/18b74b90451a7db2379a9d753c9cbaf8734b4d5d/Telegram/SourceFiles/core/core_cloud_password.cpp
def compute_check(request: types.account.Password, password: str):
    """
    Build the SRP answer (`InputCheckPasswordSRP`) for a password check.

    :param request: the password settings, carrying the current KDF
        algorithm, the server value ``srp_B`` and the ``srp_id``.
    :param password: the plain-text password to prove knowledge of.
    :raises ValueError: if the algorithm is unsupported, or any of the
        server-provided values (``p``, ``g``, ``B``) fails validation.
    """
    algo = request.current_algo
    if not isinstance(algo, types.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow):
        raise ValueError('unsupported password algorithm {}'
                         .format(algo.__class__.__name__))

    pw_hash = compute_hash(algo, password)

    p = int.from_bytes(algo.p, 'big')
    g = algo.g
    B = int.from_bytes(request.srp_B, 'big')
    try:
        check_prime_and_good(algo.p, g)
    except ValueError:
        raise ValueError('bad p/g in password')

    if not is_good_large(B, p):
        raise ValueError('bad b in check')

    x = int.from_bytes(pw_hash, 'big')
    p_for_hash = num_bytes_for_hash(algo.p)
    g_for_hash = big_num_for_hash(g)
    b_for_hash = num_bytes_for_hash(request.srp_B)
    g_x = pow(g, x, p)
    k = int.from_bytes(sha256(p_for_hash, g_for_hash), 'big')
    kg_x = (k * g_x) % p

    # Keep drawing random exponents until both the resulting public value
    # and the derived `u` are acceptable.
    while True:
        a = int.from_bytes(os.urandom(256), 'big')
        A = pow(g, a, p)
        if not is_good_mod_exp_first(A, p):
            continue

        a_for_hash = big_num_for_hash(A)
        u = int.from_bytes(sha256(a_for_hash, b_for_hash), 'big')
        if u > 0:
            break

    g_b = (B - kg_x) % p
    if not is_good_mod_exp_first(g_b, p):
        raise ValueError('bad g_b')

    S = pow(g_b, a + u * x, p)
    K = sha256(big_num_for_hash(S))
    M1 = sha256(
        xor(sha256(p_for_hash), sha256(g_for_hash)),
        sha256(algo.salt1),
        sha256(algo.salt2),
        a_for_hash,
        b_for_hash,
        K
    )

    return types.InputCheckPasswordSRP(
        request.srp_id, bytes(a_for_hash), bytes(M1))

View File

@@ -0,0 +1,116 @@
import abc
import asyncio
import time
from . import helpers
class RequestIter(abc.ABC):
    """
    Base class for asynchronous iterators that fetch their items in chunks.

    It can automatically sleep between chunk loads so that consecutive
    loads are at least `wait_time` seconds apart (but not more).

    `limit` is the total amount of items that the iterator should return.
    This is handled on this base class, and will be always ``>= 0``.

    `left` will be reset every time the iterator is used and will indicate
    the amount of items that should be emitted left, so that subclasses can
    be more efficient and fetch only as many items as they need.

    Iterators may be used with ``reversed``, and their `reverse` flag will
    be set to `True` if that's the case. Note that if this flag is set,
    `buffer` should be filled in reverse too.
    """

    def __init__(self, client, limit, *, reverse=False, wait_time=None, **kwargs):
        if limit is None:
            # No limit requested: iterate until the subclass runs out.
            limit = float('inf')

        self.client = client
        self.reverse = reverse
        self.wait_time = wait_time
        self.kwargs = kwargs
        self.limit = max(limit, 0)
        self.left = self.limit
        self.buffer = None
        self.index = 0
        self.total = None
        self.last_load = 0

    async def _init(self, **kwargs):
        """
        Hook for asynchronous initialization, run once per iteration before
        the first chunk is requested. It receives every extra keyword
        argument that was given to `__init__`; subclasses should prefer
        named arguments without defaults to avoid forgetting or
        misspelling any of them.

        This method may ``raise StopAsyncIteration`` if it cannot continue.

        This method may actually fill the initial buffer if it needs to,
        and similarly to `_load_next_chunk`, ``return True`` to indicate
        that this is the last iteration (just the initial load).
        """

    async def __anext__(self):
        if self.buffer is None:
            # First call of this iteration: run the asynchronous setup.
            self.buffer = []
            initial_load_is_last = await self._init(**self.kwargs)
            if initial_load_is_last:
                self.left = len(self.buffer)

        if self.left <= 0:  # <= 0 because subclasses may change it
            raise StopAsyncIteration

        if self.index == len(self.buffer):
            # The current buffer is exhausted, fetch the next chunk.
            if self.wait_time:
                elapsed = time.time() - self.last_load
                # asyncio will handle times <= 0 to sleep 0 seconds
                await asyncio.sleep(self.wait_time - elapsed)
                self.last_load = time.time()

            self.index = 0
            self.buffer = []
            chunk_is_last = await self._load_next_chunk()
            if chunk_is_last:
                self.left = len(self.buffer)

        if not self.buffer:
            raise StopAsyncIteration

        item = self.buffer[self.index]
        self.index += 1
        self.left -= 1
        return item

    def __aiter__(self):
        # Reset every piece of per-iteration state so that the same
        # instance can be iterated over again from scratch.
        self.left = self.limit
        self.last_load = 0
        self.index = 0
        self.buffer = None
        return self

    async def collect(self):
        """
        Iterate over `self` and gather every item into a `TotalList`
        (a normal list with a `.total` attribute).
        """
        result = helpers.TotalList()
        result.extend([item async for item in self])
        result.total = self.total
        return result

    @abc.abstractmethod
    async def _load_next_chunk(self):
        """
        Called when the next chunk is necessary.

        It should extend the `buffer` with new items.

        It should return `True` if it's the last chunk,
        after which moment the method won't be called again
        during the same iteration.
        """
        raise NotImplementedError

    def __reversed__(self):
        self.reverse = not self.reverse
        return self  # __aiter__ will be called after, too

View File

@@ -0,0 +1,164 @@
import inspect
from .tl import types
# Constructor IDs of the update types that have a `channel_id` field.
_has_channel_id = []


# TODO EntityCache does the same. Reuse?
def _fill():
    """
    Populate `_has_channel_id` by inspecting everything in `types`.

    Every object whose ``SUBCLASS_OF_ID`` is ``0x9f89304e`` and whose
    ``__init__`` takes a ``channel_id`` parameter annotated as `int` has
    its ``CONSTRUCTOR_ID`` recorded.

    :raises RuntimeError: if nothing matched, which would mean that this
        detection no longer works.
    """
    for attr_name in dir(types):
        candidate = getattr(types, attr_name)
        if getattr(candidate, 'SUBCLASS_OF_ID', None) != 0x9f89304e:
            continue

        cid = candidate.CONSTRUCTOR_ID
        init_params = inspect.signature(candidate.__init__).parameters.values()
        if any(p.name == 'channel_id' and p.annotation == int
               for p in init_params):
            _has_channel_id.append(cid)

    if not _has_channel_id:
        raise RuntimeError('FIXME: Did the init signature or updates change?')


# We use a function to avoid cluttering the globals with the loop variables
_fill()
class StateCache:
    """
    In-memory update state cache, defaultdict-like behaviour.

    The common ``(pts, date)`` pair is kept in `_pts_date`, while the
    ``pts`` of each channel is stored directly in the instance's
    ``__dict__``, keyed by the **unmarked** channel ID.
    """
    def __init__(self, initial, loggers):
        """
        :param initial: object with ``pts`` and ``date`` attributes to start
            from, or a falsy value to start with no known state.
        :param loggers: mapping from module name to logger.
        """
        # We only care about the pts and the date. By using a tuple which
        # is lightweight and immutable we can easily copy them around to
        # each update in case they need to fetch missing entities.
        self._logger = loggers[__name__]
        if initial:
            self._pts_date = initial.pts, initial.date
        else:
            self._pts_date = None, None

    def reset(self):
        """
        Forget every known state (both the common one and per-channel).
        """
        # The per-channel pts live in `__dict__` next to our own attributes,
        # so clearing it would also drop the logger. Keep it around, because
        # `update` and `get_channel_id` still need it after a reset.
        logger = self._logger
        self.__dict__.clear()
        self._logger = logger
        self._pts_date = None, None

    # TODO Call this when receiving responses too...?
    def update(
            self,
            update,
            *,
            channel_id=None,
            has_pts=frozenset(x.CONSTRUCTOR_ID for x in (
                types.UpdateNewMessage,
                types.UpdateDeleteMessages,
                types.UpdateReadHistoryInbox,
                types.UpdateReadHistoryOutbox,
                types.UpdateWebPage,
                types.UpdateReadMessagesContents,
                types.UpdateEditMessage,
                types.updates.State,
                types.updates.DifferenceTooLong,
                types.UpdateShortMessage,
                types.UpdateShortChatMessage,
                types.UpdateShortSentMessage
            )),
            has_date=frozenset(x.CONSTRUCTOR_ID for x in (
                types.UpdateUserPhoto,
                types.UpdateEncryption,
                types.UpdateEncryptedMessagesRead,
                types.UpdateChatParticipantAdd,
                types.updates.DifferenceEmpty,
                types.UpdateShortMessage,
                types.UpdateShortChatMessage,
                types.UpdateShort,
                types.UpdatesCombined,
                types.Updates,
                types.UpdateShortSentMessage,
            )),
            has_channel_pts=frozenset(x.CONSTRUCTOR_ID for x in (
                types.UpdateChannelTooLong,
                types.UpdateNewChannelMessage,
                types.UpdateDeleteChannelMessages,
                types.UpdateEditChannelMessage,
                types.UpdateChannelWebPage,
                types.updates.ChannelDifferenceEmpty,
                types.updates.ChannelDifferenceTooLong,
                types.updates.ChannelDifference
            )),
            check_only=False
    ):
        """
        Update the state with the given update.

        The ``has_*`` parameters are precomputed sets of constructor IDs
        and are not meant to be passed by callers.

        :param update: the update to take ``pts`` and/or ``date`` from.
        :param channel_id: the **unmarked** channel ID the update belongs
            to. If `None`, it is looked up with `get_channel_id`.
        :param check_only: if `True`, nothing is stored; instead, returns
            whether the update carries any state this cache tracks.
        """
        cid = update.CONSTRUCTOR_ID
        if check_only:
            return cid in has_pts or cid in has_date or cid in has_channel_pts

        # Only replace the half of the (pts, date) pair the update provides.
        if cid in has_pts:
            if cid in has_date:
                self._pts_date = update.pts, update.date
            else:
                self._pts_date = update.pts, self._pts_date[1]
        elif cid in has_date:
            self._pts_date = self._pts_date[0], update.date

        if cid in has_channel_pts:
            if channel_id is None:
                channel_id = self.get_channel_id(update)

            if channel_id is None:
                self._logger.info(
                    'Failed to retrieve channel_id from %s', update)
            else:
                self.__dict__[channel_id] = update.pts

    def get_channel_id(
            self,
            update,
            has_channel_id=frozenset(_has_channel_id),
            # Hardcoded because only some with message are for channels
            has_message=frozenset(x.CONSTRUCTOR_ID for x in (
                types.UpdateNewChannelMessage,
                types.UpdateEditChannelMessage
            ))
    ):
        """
        Gets the **unmarked** channel ID from this update, if it has any.

        Fails for ``*difference`` updates, where ``channel_id``
        is supposedly already known from the outside.

        Returns `None` when the channel ID cannot be determined.
        """
        cid = update.CONSTRUCTOR_ID
        if cid in has_channel_id:
            return update.channel_id
        elif cid in has_message:
            if update.message.peer_id is None:
                # Telegram sometimes sends empty messages to give a newer pts:
                # UpdateNewChannelMessage(message=MessageEmpty(id), pts=pts, pts_count=1)
                # Not sure why, but it's safe to ignore them.
                self._logger.debug('Update has None peer_id %s', update)
            else:
                return update.message.peer_id.channel_id

        return None

    def __getitem__(self, item):
        """
        If `item` is `None`, returns the default ``(pts, date)``.

        If it's an **unmarked** channel ID, returns its ``pts``.

        If no information is known, ``pts`` will be `None`.
        """
        if item is None:
            return self._pts_date
        else:
            return self.__dict__.get(item)

    def __setitem__(self, where, value):
        """
        If `where` is `None`, replaces the default ``(pts, date)``.

        Otherwise `where` is used as the key (a channel ID) under which
        `value` (its ``pts``) is stored.
        """
        if where is None:
            self._pts_date = value
        else:
            self.__dict__[where] = value

1559
telethon/_misc/utils.py Normal file

File diff suppressed because it is too large Load Diff