mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-08-09 13:29:47 +00:00
Correct privacy on sessions module
This commit is contained in:
4
telethon/_sessions/__init__.py
Normal file
4
telethon/_sessions/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .abstract import Session
|
||||
from .memory import MemorySession
|
||||
from .sqlite import SQLiteSession
|
||||
from .string import StringSession
|
92
telethon/_sessions/abstract.py
Normal file
92
telethon/_sessions/abstract.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from .types import DataCenter, ChannelState, SessionState, Entity
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class Session(ABC):
|
||||
@abstractmethod
|
||||
async def insert_dc(self, dc: DataCenter):
|
||||
"""
|
||||
Store a new or update an existing `DataCenter` with matching ``id``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_dc(self) -> List[DataCenter]:
|
||||
"""
|
||||
Get a list of all currently-stored `DataCenter`. Should not contain duplicate ``id``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def set_state(self, state: SessionState):
|
||||
"""
|
||||
Set the state about the current session.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_state(self) -> Optional[SessionState]:
|
||||
"""
|
||||
Get the state about the current session.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def insert_channel_state(self, state: ChannelState):
|
||||
"""
|
||||
Store a new or update an existing `ChannelState` with matching ``id``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_channel_states(self) -> List[ChannelState]:
|
||||
"""
|
||||
Get a list of all currently-stored `ChannelState`. Should not contain duplicate ``id``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def insert_entities(self, entities: List[Entity]):
|
||||
"""
|
||||
Store new or update existing `Entity` with matching ``id``.
|
||||
|
||||
Entities should be saved on a best-effort. It is okay to not save them, although the
|
||||
library may need to do extra work if a previously-saved entity is missing, or even be
|
||||
unable to continue without the entity.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
|
||||
"""
|
||||
Get the `Entity` with matching ``ty`` and ``id``.
|
||||
|
||||
The following groups of ``ty`` should be treated to be equivalent, that is, for a given
|
||||
``ty`` and ``id``, if the ``ty`` is in a given group, a matching ``access_hash`` with
|
||||
that ``id`` from within any ``ty`` in that group should be returned.
|
||||
|
||||
* ``'U'`` and ``'B'`` (user and bot).
|
||||
* ``'G'`` (small group chat).
|
||||
* ``'C'``, ``'M'`` and ``'E'`` (broadcast channel, megagroup channel, and gigagroup channel).
|
||||
|
||||
For example, if a ``ty`` representing a bot is stored but the asking ``ty`` is a user,
|
||||
the corresponding ``access_hash`` should still be returned.
|
||||
|
||||
You may use `types.canonical_entity_type` to find out the canonical type.
|
||||
|
||||
A ``ty`` with the value of ``None`` should be treated as "any entity with matching ID".
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def save(self):
|
||||
"""
|
||||
Save the session.
|
||||
|
||||
May do nothing if the other methods already saved when they were called.
|
||||
|
||||
May return custom data when manual saving is intended.
|
||||
"""
|
||||
raise NotImplementedError
|
47
telethon/_sessions/memory.py
Normal file
47
telethon/_sessions/memory.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from .types import DataCenter, ChannelState, SessionState, Entity
|
||||
from .abstract import Session
|
||||
from .._misc import utils, tlobject
|
||||
from .. import _tl
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class MemorySession(Session):
|
||||
__slots__ = ('dcs', 'state', 'channel_states', 'entities')
|
||||
|
||||
def __init__(self):
|
||||
self.dcs = {}
|
||||
self.state = None
|
||||
self.channel_states = {}
|
||||
self.entities = {}
|
||||
|
||||
async def insert_dc(self, dc: DataCenter):
|
||||
self.dcs[dc.id] = dc
|
||||
|
||||
async def get_all_dc(self) -> List[DataCenter]:
|
||||
return list(self.dcs.values())
|
||||
|
||||
async def set_state(self, state: SessionState):
|
||||
self.state = state
|
||||
|
||||
async def get_state(self) -> Optional[SessionState]:
|
||||
return self.state
|
||||
|
||||
async def insert_channel_state(self, state: ChannelState):
|
||||
self.channel_states[state.channel_id] = state
|
||||
|
||||
async def get_all_channel_states(self) -> List[ChannelState]:
|
||||
return list(self.channel_states.values())
|
||||
|
||||
async def insert_entities(self, entities: List[Entity]):
|
||||
self.entities.update((e.id, (e.ty, e.access_hash)) for e in entities)
|
||||
|
||||
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
|
||||
try:
|
||||
ty, access_hash = self.entities[id]
|
||||
return Entity(ty, id, access_hash)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
async def save(self):
|
||||
pass
|
284
telethon/_sessions/sqlite.py
Normal file
284
telethon/_sessions/sqlite.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
import ipaddress
|
||||
from typing import Optional, List
|
||||
|
||||
from .abstract import Session
|
||||
from .._misc import utils
|
||||
from .. import _tl
|
||||
from .types import DataCenter, ChannelState, SessionState, Entity
|
||||
|
||||
try:
|
||||
import sqlite3
|
||||
sqlite3_err = None
|
||||
except ImportError as e:
|
||||
sqlite3 = None
|
||||
sqlite3_err = type(e)
|
||||
|
||||
EXTENSION = '.session'
|
||||
CURRENT_VERSION = 8 # database version
|
||||
|
||||
|
||||
class SQLiteSession(Session):
|
||||
"""
|
||||
This session contains the required information to login into your
|
||||
Telegram account. NEVER give the saved session file to anyone, since
|
||||
they would gain instant access to all your messages and contacts.
|
||||
|
||||
If you think the session has been compromised, close all the sessions
|
||||
through an official Telegram client to revoke the authorization.
|
||||
"""
|
||||
|
||||
def __init__(self, session_id=None):
|
||||
if sqlite3 is None:
|
||||
raise sqlite3_err
|
||||
|
||||
super().__init__()
|
||||
self.filename = ':memory:'
|
||||
self.save_entities = True
|
||||
|
||||
if session_id:
|
||||
self.filename = session_id
|
||||
if not self.filename.endswith(EXTENSION):
|
||||
self.filename += EXTENSION
|
||||
|
||||
self._conn = None
|
||||
c = self._cursor()
|
||||
c.execute("select name from sqlite_master "
|
||||
"where type='table' and name='version'")
|
||||
if c.fetchone():
|
||||
# Tables already exist, check for the version
|
||||
c.execute("select version from version")
|
||||
version = c.fetchone()[0]
|
||||
if version < CURRENT_VERSION:
|
||||
self._upgrade_database(old=version)
|
||||
c.execute("delete from version")
|
||||
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
||||
self._conn.commit()
|
||||
else:
|
||||
# Tables don't exist, create new ones
|
||||
self._create_table(c, 'version (version integer primary key)')
|
||||
self._mk_tables(c)
|
||||
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
||||
self._conn.commit()
|
||||
|
||||
# Must have committed or else the version will not have been updated while new tables
|
||||
# exist, leading to a half-upgraded state.
|
||||
c.close()
|
||||
|
||||
def _upgrade_database(self, old):
|
||||
c = self._cursor()
|
||||
if old == 1:
|
||||
old += 1
|
||||
# old == 1 doesn't have the old sent_files so no need to drop
|
||||
if old == 2:
|
||||
old += 1
|
||||
# Old cache from old sent_files lasts then a day anyway, drop
|
||||
c.execute('drop table sent_files')
|
||||
self._create_table(c, """sent_files (
|
||||
md5_digest blob,
|
||||
file_size integer,
|
||||
type integer,
|
||||
id integer,
|
||||
hash integer,
|
||||
primary key(md5_digest, file_size, type)
|
||||
)""")
|
||||
if old == 3:
|
||||
old += 1
|
||||
self._create_table(c, """update_state (
|
||||
id integer primary key,
|
||||
pts integer,
|
||||
qts integer,
|
||||
date integer,
|
||||
seq integer
|
||||
)""")
|
||||
if old == 4:
|
||||
old += 1
|
||||
c.execute("alter table sessions add column takeout_id integer")
|
||||
if old == 5:
|
||||
# Not really any schema upgrade, but potentially all access
|
||||
# hashes for User and Channel are wrong, so drop them off.
|
||||
old += 1
|
||||
c.execute('delete from entities')
|
||||
if old == 6:
|
||||
old += 1
|
||||
c.execute("alter table entities add column date integer")
|
||||
if old == 7:
|
||||
self._mk_tables(c)
|
||||
c.execute('''
|
||||
insert into datacenter (id, ipv4, ipv6, port, auth)
|
||||
select dc_id, server_address, server_address, port, auth_key
|
||||
from sessions
|
||||
''')
|
||||
c.execute('''
|
||||
insert into session (user_id, dc_id, bot, pts, qts, date, seq, takeout_id)
|
||||
select
|
||||
0,
|
||||
s.dc_id,
|
||||
0,
|
||||
coalesce(u.pts, 0),
|
||||
coalesce(u.qts, 0),
|
||||
coalesce(u.date, 0),
|
||||
coalesce(u.seq, 0),
|
||||
s.takeout_id
|
||||
from sessions s
|
||||
left join update_state u on u.id = 0
|
||||
limit 1
|
||||
''')
|
||||
c.execute('''
|
||||
insert into entity (id, access_hash, ty)
|
||||
select
|
||||
case
|
||||
when id < -1000000000000 then -(id + 1000000000000)
|
||||
when id < 0 then -id
|
||||
else id
|
||||
end,
|
||||
hash,
|
||||
case
|
||||
when id < -1000000000000 then 67
|
||||
when id < 0 then 71
|
||||
else 85
|
||||
end
|
||||
from entities
|
||||
''')
|
||||
c.execute('drop table sessions')
|
||||
c.execute('drop table entities')
|
||||
c.execute('drop table sent_files')
|
||||
c.execute('drop table update_state')
|
||||
|
||||
def _mk_tables(self, c):
|
||||
self._create_table(
|
||||
c,
|
||||
'''datacenter (
|
||||
id integer primary key,
|
||||
ipv4 text not null,
|
||||
ipv6 text,
|
||||
port integer not null,
|
||||
auth blob not null
|
||||
)''',
|
||||
'''session (
|
||||
user_id integer primary key,
|
||||
dc_id integer not null,
|
||||
bot integer not null,
|
||||
pts integer not null,
|
||||
qts integer not null,
|
||||
date integer not null,
|
||||
seq integer not null,
|
||||
takeout_id integer
|
||||
)''',
|
||||
'''channel (
|
||||
channel_id integer primary key,
|
||||
pts integer not null
|
||||
)''',
|
||||
'''entity (
|
||||
id integer primary key,
|
||||
access_hash integer not null,
|
||||
ty integer not null
|
||||
)''',
|
||||
)
|
||||
|
||||
async def insert_dc(self, dc: DataCenter):
|
||||
self._execute(
|
||||
'insert or replace into datacenter values (?,?,?,?,?)',
|
||||
dc.id,
|
||||
str(ipaddress.ip_address(dc.ipv4)),
|
||||
str(ipaddress.ip_address(dc.ipv6)) if dc.ipv6 else None,
|
||||
dc.port,
|
||||
dc.auth
|
||||
)
|
||||
|
||||
async def get_all_dc(self) -> List[DataCenter]:
|
||||
c = self._cursor()
|
||||
res = []
|
||||
for (id, ipv4, ipv6, port, auth) in c.execute('select * from datacenter'):
|
||||
res.append(DataCenter(
|
||||
id=id,
|
||||
ipv4=int(ipaddress.ip_address(ipv4)),
|
||||
ipv6=int(ipaddress.ip_address(ipv6)) if ipv6 else None,
|
||||
port=port,
|
||||
auth=auth,
|
||||
))
|
||||
return res
|
||||
|
||||
async def set_state(self, state: SessionState):
|
||||
c = self._cursor()
|
||||
try:
|
||||
self._execute('delete from session')
|
||||
self._execute(
|
||||
'insert into session values (?,?,?,?,?,?,?,?)',
|
||||
state.user_id,
|
||||
state.dc_id,
|
||||
int(state.bot),
|
||||
state.pts,
|
||||
state.qts,
|
||||
state.date,
|
||||
state.seq,
|
||||
state.takeout_id,
|
||||
)
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
async def get_state(self) -> Optional[SessionState]:
|
||||
row = self._execute('select * from session')
|
||||
return SessionState(*row) if row else None
|
||||
|
||||
async def insert_channel_state(self, state: ChannelState):
|
||||
self._execute(
|
||||
'insert or replace into channel values (?,?)',
|
||||
state.channel_id,
|
||||
state.pts,
|
||||
)
|
||||
|
||||
async def get_all_channel_states(self) -> List[ChannelState]:
|
||||
c = self._cursor()
|
||||
try:
|
||||
return [
|
||||
ChannelState(*row)
|
||||
for row in c.execute('select * from channel')
|
||||
]
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
async def insert_entities(self, entities: List[Entity]):
|
||||
c = self._cursor()
|
||||
try:
|
||||
c.executemany(
|
||||
'insert or replace into entity values (?,?,?)',
|
||||
[(e.id, e.access_hash, e.ty) for e in entities]
|
||||
)
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
|
||||
row = self._execute('select ty, id, access_hash from entity where id = ?', id)
|
||||
return Entity(*row) if row else None
|
||||
|
||||
async def save(self):
|
||||
# This is a no-op if there are no changes to commit, so there's
|
||||
# no need for us to keep track of an "unsaved changes" variable.
|
||||
if self._conn is not None:
|
||||
self._conn.commit()
|
||||
|
||||
@staticmethod
|
||||
def _create_table(c, *definitions):
|
||||
for definition in definitions:
|
||||
c.execute('create table {}'.format(definition))
|
||||
|
||||
def _cursor(self):
|
||||
"""Asserts that the connection is open and returns a cursor"""
|
||||
if self._conn is None:
|
||||
self._conn = sqlite3.connect(self.filename,
|
||||
check_same_thread=False)
|
||||
return self._conn.cursor()
|
||||
|
||||
def _execute(self, stmt, *values):
|
||||
"""
|
||||
Gets a cursor, executes `stmt` and closes the cursor,
|
||||
fetching one row afterwards and returning its result.
|
||||
"""
|
||||
c = self._cursor()
|
||||
try:
|
||||
return c.execute(stmt, values).fetchone()
|
||||
finally:
|
||||
c.close()
|
88
telethon/_sessions/string.py
Normal file
88
telethon/_sessions/string.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import base64
|
||||
import ipaddress
|
||||
import struct
|
||||
|
||||
from .abstract import Session
|
||||
from .memory import MemorySession
|
||||
from .types import DataCenter, ChannelState, SessionState, Entity
|
||||
|
||||
_STRUCT_PREFORMAT = '>B{}sH256s'
|
||||
|
||||
CURRENT_VERSION = '1'
|
||||
|
||||
|
||||
class StringSession(MemorySession):
|
||||
"""
|
||||
This session file can be easily saved and loaded as a string. According
|
||||
to the initial design, it contains only the data that is necessary for
|
||||
successful connection and authentication, so takeout ID is not stored.
|
||||
|
||||
It is thought to be used where you don't want to create any on-disk
|
||||
files but would still like to be able to save and load existing sessions
|
||||
by other means.
|
||||
|
||||
You can use custom `encode` and `decode` functions, if present:
|
||||
|
||||
* `encode` definition must be ``def encode(value: bytes) -> str:``.
|
||||
* `decode` definition must be ``def decode(value: str) -> bytes:``.
|
||||
"""
|
||||
def __init__(self, string: str = None):
|
||||
super().__init__()
|
||||
if string:
|
||||
if string[0] != CURRENT_VERSION:
|
||||
raise ValueError('Not a valid string')
|
||||
|
||||
string = string[1:]
|
||||
ip_len = 4 if len(string) == 352 else 16
|
||||
dc_id, ip, port, key = struct.unpack(
|
||||
_STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string))
|
||||
|
||||
self.state = SessionState(
|
||||
dc_id=dc_id,
|
||||
user_id=0,
|
||||
bot=False,
|
||||
pts=0,
|
||||
qts=0,
|
||||
date=0,
|
||||
seq=0,
|
||||
takeout_id=0
|
||||
)
|
||||
if ip_len == 4:
|
||||
ipv4 = int.from_bytes(ip, 'big', False)
|
||||
ipv6 = None
|
||||
else:
|
||||
ipv4 = None
|
||||
ipv6 = int.from_bytes(ip, 'big', signed=False)
|
||||
|
||||
self.dcs[dc_id] = DataCenter(
|
||||
id=dc_id,
|
||||
ipv4=ipv4,
|
||||
ipv6=ipv6,
|
||||
port=port,
|
||||
auth=key
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def encode(x: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(x).decode('ascii')
|
||||
|
||||
@staticmethod
|
||||
def decode(x: str) -> bytes:
|
||||
return base64.urlsafe_b64decode(x)
|
||||
|
||||
def save(self: Session):
|
||||
if not self.state:
|
||||
return ''
|
||||
|
||||
if self.state.ipv6 is not None:
|
||||
ip = self.state.ipv6.to_bytes(16, 'big', signed=False)
|
||||
else:
|
||||
ip = self.state.ipv6.to_bytes(4, 'big', signed=False)
|
||||
|
||||
return CURRENT_VERSION + StringSession.encode(struct.pack(
|
||||
_STRUCT_PREFORMAT.format(len(ip)),
|
||||
self.state.dc_id,
|
||||
ip,
|
||||
self.state.port,
|
||||
self.dcs[self.state.dc_id].auth
|
||||
))
|
157
telethon/_sessions/types.py
Normal file
157
telethon/_sessions/types.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class DataCenter:
|
||||
"""
|
||||
Stores the information needed to connect to a datacenter.
|
||||
|
||||
* id: 32-bit number representing the datacenter identifier as given by Telegram.
|
||||
* ipv4 and ipv6: 32-bit or 128-bit number storing the IP address of the datacenter.
|
||||
* port: 16-bit number storing the port number needed to connect to the datacenter.
|
||||
* bytes: arbitrary binary payload needed to authenticate to the datacenter.
|
||||
"""
|
||||
__slots__ = ('id', 'ipv4', 'ipv6', 'port', 'auth')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
ipv4: int,
|
||||
ipv6: Optional[int],
|
||||
port: int,
|
||||
auth: bytes
|
||||
):
|
||||
self.id = id
|
||||
self.ipv4 = ipv4
|
||||
self.ipv6 = ipv6
|
||||
self.port = port
|
||||
self.auth = auth
|
||||
|
||||
|
||||
class SessionState:
|
||||
"""
|
||||
Stores the information needed to fetch updates and about the current user.
|
||||
|
||||
* user_id: 64-bit number representing the user identifier.
|
||||
* dc_id: 32-bit number relating to the datacenter identifier where the user is.
|
||||
* bot: is the logged-in user a bot?
|
||||
* pts: 64-bit number holding the state needed to fetch updates.
|
||||
* qts: alternative 64-bit number holding the state needed to fetch updates.
|
||||
* date: 64-bit number holding the date needed to fetch updates.
|
||||
* seq: 64-bit-number holding the sequence number needed to fetch updates.
|
||||
* takeout_id: 64-bit-number holding the identifier of the current takeout session.
|
||||
|
||||
Note that some of the numbers will only use 32 out of the 64 available bits.
|
||||
However, for future-proofing reasons, we recommend you pretend they are 64-bit long.
|
||||
"""
|
||||
__slots__ = ('user_id', 'dc_id', 'bot', 'pts', 'qts', 'date', 'seq', 'takeout_id')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: int,
|
||||
dc_id: int,
|
||||
bot: bool,
|
||||
pts: int,
|
||||
qts: int,
|
||||
date: int,
|
||||
seq: int,
|
||||
takeout_id: Optional[int],
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.dc_id = dc_id
|
||||
self.bot = bot
|
||||
self.pts = pts
|
||||
self.qts = qts
|
||||
self.date = date
|
||||
self.seq = seq
|
||||
self.takeout_id = takeout_id
|
||||
|
||||
|
||||
class ChannelState:
|
||||
"""
|
||||
Stores the information needed to fetch updates from a channel.
|
||||
|
||||
* channel_id: 64-bit number representing the channel identifier.
|
||||
* pts: 64-bit number holding the state needed to fetch updates.
|
||||
"""
|
||||
__slots__ = ('channel_id', 'pts')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channel_id: int,
|
||||
pts: int
|
||||
):
|
||||
self.channel_id = channel_id
|
||||
self.pts = pts
|
||||
|
||||
|
||||
class Entity:
|
||||
"""
|
||||
Stores the information needed to use a certain user, chat or channel with the API.
|
||||
|
||||
* ty: 8-bit number indicating the type of the entity.
|
||||
* id: 64-bit number uniquely identifying the entity among those of the same type.
|
||||
* access_hash: 64-bit number needed to use this entity with the API.
|
||||
|
||||
You can rely on the ``ty`` value to be equal to the ASCII character one of:
|
||||
|
||||
* 'U' (85): this entity belongs to a :tl:`User` who is not a ``bot``.
|
||||
* 'B' (66): this entity belongs to a :tl:`User` who is a ``bot``.
|
||||
* 'G' (71): this entity belongs to a small group :tl:`Chat`.
|
||||
* 'C' (67): this entity belongs to a standard broadcast :tl:`Channel`.
|
||||
* 'M' (77): this entity belongs to a megagroup :tl:`Channel`.
|
||||
* 'E' (69): this entity belongs to an "enormous" "gigagroup" :tl:`Channel`.
|
||||
"""
|
||||
__slots__ = ('ty', 'id', 'access_hash')
|
||||
|
||||
USER = ord('U')
|
||||
BOT = ord('B')
|
||||
GROUP = ord('G')
|
||||
CHANNEL = ord('C')
|
||||
MEGAGROUP = ord('M')
|
||||
GIGAGROUP = ord('E')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ty: int,
|
||||
id: int,
|
||||
access_hash: int
|
||||
):
|
||||
self.ty = ty
|
||||
self.id = id
|
||||
self.access_hash = access_hash
|
||||
|
||||
|
||||
def canonical_entity_type(ty: int, *, _mapping={
|
||||
Entity.USER: Entity.USER,
|
||||
Entity.BOT: Entity.USER,
|
||||
Entity.GROUP: Entity.GROUP,
|
||||
Entity.CHANNEL: Entity.CHANNEL,
|
||||
Entity.MEGAGROUP: Entity.CHANNEL,
|
||||
Entity.GIGAGROUP: Entity.CHANNEL,
|
||||
}) -> int:
|
||||
"""
|
||||
Return the canonical version of an entity type.
|
||||
"""
|
||||
try:
|
||||
return _mapping[ty]
|
||||
except KeyError:
|
||||
ty = chr(ty) if isinstance(ty, int) else ty
|
||||
raise ValueError(f'entity type {ty!r} is not valid')
|
||||
|
||||
|
||||
def get_entity_type_group(ty: int, *, _mapping={
|
||||
Entity.USER: (Entity.USER, Entity.BOT),
|
||||
Entity.BOT: (Entity.USER, Entity.BOT),
|
||||
Entity.GROUP: (Entity.GROUP,),
|
||||
Entity.CHANNEL: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
|
||||
Entity.MEGAGROUP: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
|
||||
Entity.GIGAGROUP: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
|
||||
}) -> Tuple[int]:
|
||||
"""
|
||||
Return the group where an entity type belongs to.
|
||||
"""
|
||||
try:
|
||||
return _mapping[ty]
|
||||
except KeyError:
|
||||
ty = chr(ty) if isinstance(ty, int) else ty
|
||||
raise ValueError(f'entity type {ty!r} is not valid')
|
Reference in New Issue
Block a user