Correct privacy on sessions module

This commit is contained in:
Lonami Exo
2021-09-19 18:24:16 +02:00
parent 26f6c62ce4
commit cfe47a0434
7 changed files with 12 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from .abstract import Session
from .memory import MemorySession
from .sqlite import SQLiteSession
from .string import StringSession

View File

@@ -0,0 +1,92 @@
from .types import DataCenter, ChannelState, SessionState, Entity
from abc import ABC, abstractmethod
from typing import List, Optional
class Session(ABC):
    """
    Abstract interface every session storage backend has to implement.

    A session stores four kinds of data: datacenters, the state of the current
    session, per-channel states and entities. Each kind has a coroutine to store
    it and a coroutine to read it back, plus `save` to persist the session.
    """

    @abstractmethod
    async def insert_dc(self, dc: DataCenter):
        """
        Store ``dc``. If a `DataCenter` with the same ``id`` is already stored,
        it must be updated instead of duplicated.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_all_dc(self) -> List[DataCenter]:
        """
        Return every `DataCenter` stored so far.

        The returned list should not contain two items with the same ``id``.
        """
        raise NotImplementedError

    @abstractmethod
    async def set_state(self, state: SessionState):
        """
        Replace the state about the current session with ``state``.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_state(self) -> Optional[SessionState]:
        """
        Return the state about the current session (may be ``None``).
        """
        raise NotImplementedError

    @abstractmethod
    async def insert_channel_state(self, state: ChannelState):
        """
        Store ``state``. If a `ChannelState` with the same ``id`` is already
        stored, it must be updated instead of duplicated.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_all_channel_states(self) -> List[ChannelState]:
        """
        Return every `ChannelState` stored so far.

        The returned list should not contain two items with the same ``id``.
        """
        raise NotImplementedError

    @abstractmethod
    async def insert_entities(self, entities: List[Entity]):
        """
        Store each `Entity` in ``entities``, updating those whose ``id`` is
        already stored.

        Saving entities is best-effort. An implementation is allowed not to
        save them, but the library may then have to do extra work when a
        previously-saved entity is missing, or may even be unable to continue
        without it.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
        """
        Return the `Entity` matching ``ty`` and ``id``.

        Types belonging to the same group below are interchangeable: when the
        requested ``ty`` is in a group, the ``access_hash`` stored for that
        ``id`` under any ``ty`` of the same group should be returned.

        * ``'U'`` and ``'B'`` (user and bot).
        * ``'G'`` (small group chat).
        * ``'C'``, ``'M'`` and ``'E'`` (broadcast channel, megagroup channel,
          and gigagroup channel).

        For instance, asking for a user ``ty`` while a bot ``ty`` is stored
        should still return the corresponding ``access_hash``.

        `types.canonical_entity_type` can be used to obtain the canonical type.

        A ``ty`` of ``None`` means "any entity with matching ID".
        """
        raise NotImplementedError

    @abstractmethod
    async def save(self):
        """
        Persist the session.

        This may be a no-op when the other methods already saved the data as
        they were called. It may also return custom data for implementations
        where saving is meant to be done manually.
        """
        raise NotImplementedError

View File

@@ -0,0 +1,47 @@
from .types import DataCenter, ChannelState, SessionState, Entity
from .abstract import Session
from .._misc import utils, tlobject
from .. import _tl
from typing import List, Optional
class MemorySession(Session):
    """
    Session backend that keeps all data in instance attributes.

    Nothing is written anywhere else, so `save` has nothing to do.
    """
    __slots__ = ('dcs', 'state', 'channel_states', 'entities')

    def __init__(self):
        # dcs: DataCenter.id -> DataCenter
        # channel_states: ChannelState.channel_id -> ChannelState
        # entities: Entity.id -> (ty, access_hash)
        self.state = None
        self.dcs = {}
        self.channel_states = {}
        self.entities = {}

    async def insert_dc(self, dc: DataCenter):
        """Store ``dc`` under its ``id``, replacing any previous one."""
        self.dcs[dc.id] = dc

    async def get_all_dc(self) -> List[DataCenter]:
        """Return a new list with every stored `DataCenter`."""
        return [*self.dcs.values()]

    async def set_state(self, state: SessionState):
        """Remember ``state`` as the current session state."""
        self.state = state

    async def get_state(self) -> Optional[SessionState]:
        """Return the last state set, or ``None`` if none was set."""
        return self.state

    async def insert_channel_state(self, state: ChannelState):
        """Store ``state`` under its ``channel_id``, replacing any previous one."""
        self.channel_states[state.channel_id] = state

    async def get_all_channel_states(self) -> List[ChannelState]:
        """Return a new list with every stored `ChannelState`."""
        return [*self.channel_states.values()]

    async def insert_entities(self, entities: List[Entity]):
        """Store the type and access hash of each entity, keyed by its ``id``."""
        for entity in entities:
            self.entities[entity.id] = (entity.ty, entity.access_hash)

    async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
        """
        Return the stored `Entity` with the given ``id``, or ``None``.

        The lookup uses ``id`` only; the returned entity carries the stored
        type rather than the requested ``ty``.
        """
        stored = self.entities.get(id)
        if stored is None:
            return None

        stored_ty, access_hash = stored
        return Entity(stored_ty, id, access_hash)

    async def save(self):
        """Do nothing: all data already lives in memory."""
        pass

View File

@@ -0,0 +1,284 @@
import datetime
import os
import time
import ipaddress
from typing import Optional, List
from .abstract import Session
from .._misc import utils
from .. import _tl
from .types import DataCenter, ChannelState, SessionState, Entity
# The sqlite3 import is allowed to fail. The exception type is remembered so
# that SQLiteSession can raise it when it is instantiated, rather than making
# the import of this module fail.
try:
    import sqlite3
    sqlite3_err = None
except ImportError as e:
    sqlite3 = None
    sqlite3_err = type(e)

EXTENSION = '.session'  # suffix appended to session file names lacking it
CURRENT_VERSION = 8 # database version
class SQLiteSession(Session):
    """
    This session contains the required information to login into your
    Telegram account. NEVER give the saved session file to anyone, since
    they would gain instant access to all your messages and contacts.

    If you think the session has been compromised, close all the sessions
    through an official Telegram client to revoke the authorization.
    """

    def __init__(self, session_id=None):
        """
        Open the database, creating its tables or upgrading them as needed.

        :param session_id: file name of the database. ``EXTENSION`` is
            appended when the name does not end with it. When falsy, the
            in-memory database ``':memory:'`` is used.
        :raises: the exception type recorded when importing ``sqlite3``
            failed, if it did.
        """
        if sqlite3 is None:
            raise sqlite3_err

        super().__init__()
        self.filename = ':memory:'
        self.save_entities = True

        if session_id:
            self.filename = session_id
            if not self.filename.endswith(EXTENSION):
                self.filename += EXTENSION

        self._conn = None
        c = self._cursor()
        try:
            c.execute("select name from sqlite_master "
                      "where type='table' and name='version'")
            if c.fetchone():
                # Tables already exist, check for the version
                c.execute("select version from version")
                version = c.fetchone()[0]
                if version < CURRENT_VERSION:
                    self._upgrade_database(old=version)
                    c.execute("delete from version")
                    c.execute("insert into version values (?)", (CURRENT_VERSION,))
                    self._conn.commit()
            else:
                # Tables don't exist, create new ones
                self._create_table(c, 'version (version integer primary key)')
                self._mk_tables(c)
                c.execute("insert into version values (?)", (CURRENT_VERSION,))
                self._conn.commit()

            # Must have committed or else the version will not have been updated
            # while new tables exist, leading to a half-upgraded state.
        finally:
            # Close the cursor even if the upgrade or table creation failed.
            c.close()

    def _upgrade_database(self, old):
        """
        Migrate the schema step by step from version ``old`` to the current one.

        Each ``if`` handles a single version bump and falls through to the
        next one, so a database of any older version ends up fully upgraded.
        The caller is responsible for updating the ``version`` table and
        committing.
        """
        c = self._cursor()
        try:
            if old == 1:
                old += 1
                # old == 1 doesn't have the old sent_files so no need to drop

            if old == 2:
                old += 1
                # Old cache from old sent_files lasts less than a day anyway, drop.
                # This step is also reached when coming from version 1, where the
                # table does not exist, hence "if exists".
                c.execute('drop table if exists sent_files')
                self._create_table(c, """sent_files (
md5_digest blob,
file_size integer,
type integer,
id integer,
hash integer,
primary key(md5_digest, file_size, type)
)""")

            if old == 3:
                old += 1
                self._create_table(c, """update_state (
id integer primary key,
pts integer,
qts integer,
date integer,
seq integer
)""")

            if old == 4:
                old += 1
                c.execute("alter table sessions add column takeout_id integer")

            if old == 5:
                # Not really any schema upgrade, but potentially all access
                # hashes for User and Channel are wrong, so drop them off.
                old += 1
                c.execute('delete from entities')

            if old == 6:
                old += 1
                c.execute("alter table entities add column date integer")

            if old == 7:
                # Move the data from the old tables into the new layout created
                # by _mk_tables, then get rid of the old tables.
                self._mk_tables(c)
                c.execute('''
insert into datacenter (id, ipv4, ipv6, port, auth)
select dc_id, server_address, server_address, port, auth_key
from sessions
''')
                c.execute('''
insert into session (user_id, dc_id, bot, pts, qts, date, seq, takeout_id)
select
0,
s.dc_id,
0,
coalesce(u.pts, 0),
coalesce(u.qts, 0),
coalesce(u.date, 0),
coalesce(u.seq, 0),
s.takeout_id
from sessions s
left join update_state u on u.id = 0
limit 1
''')
                # Old identifiers encoded the entity kind in their sign and
                # magnitude; the new table stores a positive id plus a type
                # (67 = 'C', 71 = 'G', 85 = 'U').
                c.execute('''
insert into entity (id, access_hash, ty)
select
case
when id < -1000000000000 then -(id + 1000000000000)
when id < 0 then -id
else id
end,
hash,
case
when id < -1000000000000 then 67
when id < 0 then 71
else 85
end
from entities
''')
                c.execute('drop table sessions')
                c.execute('drop table entities')
                c.execute('drop table sent_files')
                c.execute('drop table update_state')
        finally:
            c.close()

    def _mk_tables(self, c):
        """Create the tables of the current schema using cursor ``c``."""
        self._create_table(
            c,
            '''datacenter (
id integer primary key,
ipv4 text not null,
ipv6 text,
port integer not null,
auth blob not null
)''',
            '''session (
user_id integer primary key,
dc_id integer not null,
bot integer not null,
pts integer not null,
qts integer not null,
date integer not null,
seq integer not null,
takeout_id integer
)''',
            '''channel (
channel_id integer primary key,
pts integer not null
)''',
            '''entity (
id integer primary key,
access_hash integer not null,
ty integer not null
)''',
        )

    async def insert_dc(self, dc: DataCenter):
        """Insert or replace ``dc``; addresses are stored in their textual form."""
        self._execute(
            'insert or replace into datacenter values (?,?,?,?,?)',
            dc.id,
            str(ipaddress.ip_address(dc.ipv4)),
            str(ipaddress.ip_address(dc.ipv6)) if dc.ipv6 else None,
            dc.port,
            dc.auth
        )

    async def get_all_dc(self) -> List[DataCenter]:
        """Return every stored `DataCenter`, with addresses converted back to integers."""
        c = self._cursor()
        try:
            return [
                DataCenter(
                    id=id,
                    ipv4=int(ipaddress.ip_address(ipv4)),
                    ipv6=int(ipaddress.ip_address(ipv6)) if ipv6 else None,
                    port=port,
                    auth=auth,
                )
                for (id, ipv4, ipv6, port, auth) in c.execute('select * from datacenter')
            ]
        finally:
            c.close()

    async def set_state(self, state: SessionState):
        """Replace the single stored session row with ``state``."""
        self._execute('delete from session')
        self._execute(
            'insert into session values (?,?,?,?,?,?,?,?)',
            state.user_id,
            state.dc_id,
            int(state.bot),
            state.pts,
            state.qts,
            state.date,
            state.seq,
            state.takeout_id,
        )

    async def get_state(self) -> Optional[SessionState]:
        """Return the stored `SessionState`, or ``None`` if there is no row."""
        row = self._execute('select * from session')
        if not row:
            return None

        # ``bot`` is stored as an integer by set_state; hand it back as a bool.
        user_id, dc_id, bot, *rest = row
        return SessionState(user_id, dc_id, bool(bot), *rest)

    async def insert_channel_state(self, state: ChannelState):
        """Insert or replace the state of the channel ``state.channel_id``."""
        self._execute(
            'insert or replace into channel values (?,?)',
            state.channel_id,
            state.pts,
        )

    async def get_all_channel_states(self) -> List[ChannelState]:
        """Return every stored `ChannelState`."""
        c = self._cursor()
        try:
            return [
                ChannelState(*row)
                for row in c.execute('select * from channel')
            ]
        finally:
            c.close()

    async def insert_entities(self, entities: List[Entity]):
        """Insert or replace every entity in ``entities``."""
        c = self._cursor()
        try:
            c.executemany(
                'insert or replace into entity values (?,?,?)',
                [(e.id, e.access_hash, e.ty) for e in entities]
            )
        finally:
            c.close()

    async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
        """
        Return the stored `Entity` with the given ``id``, or ``None``.

        Only ``id`` takes part in the query; ``ty`` is not used to filter.
        """
        row = self._execute('select ty, id, access_hash from entity where id = ?', id)
        return Entity(*row) if row else None

    async def save(self):
        """Commit pending changes, if a connection is open."""
        # This is a no-op if there are no changes to commit, so there's
        # no need for us to keep track of an "unsaved changes" variable.
        if self._conn is not None:
            self._conn.commit()

    @staticmethod
    def _create_table(c, *definitions):
        """Run ``create table <definition>`` on cursor ``c`` for each definition."""
        for definition in definitions:
            c.execute('create table {}'.format(definition))

    def _cursor(self):
        """Asserts that the connection is open and returns a cursor"""
        if self._conn is None:
            self._conn = sqlite3.connect(self.filename,
                                         check_same_thread=False)
        return self._conn.cursor()

    def _execute(self, stmt, *values):
        """
        Gets a cursor, executes `stmt` and closes the cursor,
        fetching one row afterwards and returning its result.
        """
        c = self._cursor()
        try:
            return c.execute(stmt, values).fetchone()
        finally:
            c.close()

View File

@@ -0,0 +1,88 @@
import base64
import ipaddress
import struct
from .abstract import Session
from .memory import MemorySession
from .types import DataCenter, ChannelState, SessionState, Entity
# struct layout of the serialized session: big-endian, 1-byte datacenter id,
# N-byte IP address (N is filled in with str.format), 2-byte port and a
# 256-byte authentication key.
_STRUCT_PREFORMAT = '>B{}sH256s'
# First character of every serialized string; it is checked when loading one.
CURRENT_VERSION = '1'
class StringSession(MemorySession):
    """
    This session file can be easily saved and loaded as a string. According
    to the initial design, it contains only the data that is necessary for
    successful connection and authentication, so takeout ID is not stored.

    It is thought to be used where you don't want to create any on-disk
    files but would still like to be able to save and load existing sessions
    by other means.

    You can use custom `encode` and `decode` functions, if present:

    * `encode` definition must be ``def encode(value: bytes) -> str:``.
    * `decode` definition must be ``def decode(value: str) -> bytes:``.
    """

    def __init__(self, string: str = None):
        """
        Create an empty session, or load one from ``string``.

        :param string: value previously returned by `save`. When falsy, the
            session starts empty.
        :raises ValueError: if ``string`` does not start with ``CURRENT_VERSION``.
        """
        super().__init__()
        if string:
            if string[0] != CURRENT_VERSION:
                raise ValueError('Not a valid string')

            string = string[1:]
            # With a 4-byte address the packed data is 1 + 4 + 2 + 256 = 263
            # bytes long, which encodes to 352 base64 characters. Anything
            # else is treated as carrying a 16-byte (IPv6) address.
            ip_len = 4 if len(string) == 352 else 16
            dc_id, ip, port, key = struct.unpack(
                _STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string))

            # Only the datacenter is known from the string; the rest of the
            # state starts zeroed.
            self.state = SessionState(
                dc_id=dc_id,
                user_id=0,
                bot=False,
                pts=0,
                qts=0,
                date=0,
                seq=0,
                takeout_id=0
            )
            # ``signed`` is keyword-only in int.from_bytes, so it must be
            # passed by name.
            if ip_len == 4:
                ipv4 = int.from_bytes(ip, 'big', signed=False)
                ipv6 = None
            else:
                ipv4 = None
                ipv6 = int.from_bytes(ip, 'big', signed=False)

            self.dcs[dc_id] = DataCenter(
                id=dc_id,
                ipv4=ipv4,
                ipv6=ipv6,
                port=port,
                auth=key
            )

    @staticmethod
    def encode(x: bytes) -> str:
        """Encode ``x`` as URL-safe base64 text."""
        return base64.urlsafe_b64encode(x).decode('ascii')

    @staticmethod
    def decode(x: str) -> bytes:
        """Decode URL-safe base64 text back into bytes."""
        return base64.urlsafe_b64decode(x)

    def save(self):
        """
        Serialize the session into a string that `__init__` can load back.

        :return: ``''`` when no state is set; otherwise ``CURRENT_VERSION``
            followed by the encoded datacenter id, address, port and key of
            the datacenter the state refers to.
        :raises KeyError: if the datacenter of the current state is not stored.
        """
        if not self.state:
            return ''

        # Address, port and key belong to the datacenter, not to the state.
        dc = self.dcs[self.state.dc_id]
        if dc.ipv6 is not None:
            ip = dc.ipv6.to_bytes(16, 'big', signed=False)
        else:
            ip = dc.ipv4.to_bytes(4, 'big', signed=False)

        return CURRENT_VERSION + StringSession.encode(struct.pack(
            _STRUCT_PREFORMAT.format(len(ip)),
            self.state.dc_id,
            ip,
            dc.port,
            dc.auth
        ))

157
telethon/_sessions/types.py Normal file
View File

@@ -0,0 +1,157 @@
from typing import Optional, Tuple
class DataCenter:
    """
    Connection details of a single datacenter.

    * id: 32-bit number representing the datacenter identifier as given by Telegram.
    * ipv4 and ipv6: 32-bit or 128-bit number storing the IP address of the datacenter.
    * port: 16-bit number storing the port number needed to connect to the datacenter.
    * auth: arbitrary binary payload needed to authenticate to the datacenter.
    """
    __slots__ = ('id', 'ipv4', 'ipv6', 'port', 'auth')

    def __init__(
            self,
            id: int,
            ipv4: int,
            ipv6: Optional[int],
            port: int,
            auth: bytes
    ):
        # Plain value holder: every argument is stored as given.
        self.id, self.ipv4, self.ipv6 = id, ipv4, ipv6
        self.port, self.auth = port, auth
class SessionState:
    """
    State of the logged-in user together with what is needed to fetch updates.

    * user_id: 64-bit number representing the user identifier.
    * dc_id: 32-bit number relating to the datacenter identifier where the user is.
    * bot: is the logged-in user a bot?
    * pts: 64-bit number holding the state needed to fetch updates.
    * qts: alternative 64-bit number holding the state needed to fetch updates.
    * date: 64-bit number holding the date needed to fetch updates.
    * seq: 64-bit-number holding the sequence number needed to fetch updates.
    * takeout_id: 64-bit-number holding the identifier of the current takeout session.

    Some of these numbers only use 32 out of the 64 available bits. For
    future-proofing reasons it is recommended to treat them all as 64-bit long.
    """
    __slots__ = ('user_id', 'dc_id', 'bot', 'pts', 'qts', 'date', 'seq', 'takeout_id')

    def __init__(
            self,
            user_id: int,
            dc_id: int,
            bot: bool,
            pts: int,
            qts: int,
            date: int,
            seq: int,
            takeout_id: Optional[int],
    ):
        # Who is logged in and where.
        self.user_id, self.dc_id, self.bot = user_id, dc_id, bot
        # Counters needed to fetch updates.
        self.pts, self.qts, self.date, self.seq = pts, qts, date, seq
        self.takeout_id = takeout_id
class ChannelState:
    """
    What is needed to fetch updates from one channel.

    * channel_id: 64-bit number representing the channel identifier.
    * pts: 64-bit number holding the state needed to fetch updates.
    """
    __slots__ = ('channel_id', 'pts')

    def __init__(
            self,
            channel_id: int,
            pts: int
    ):
        self.channel_id, self.pts = channel_id, pts
class Entity:
    """
    What is needed to use a certain user, chat or channel with the API.

    * ty: 8-bit number indicating the type of the entity.
    * id: 64-bit number uniquely identifying the entity among those of the same type.
    * access_hash: 64-bit number needed to use this entity with the API.

    The ``ty`` value can be relied upon to equal the ASCII code of one of:

    * 'U' (85): this entity belongs to a :tl:`User` who is not a ``bot``.
    * 'B' (66): this entity belongs to a :tl:`User` who is a ``bot``.
    * 'G' (71): this entity belongs to a small group :tl:`Chat`.
    * 'C' (67): this entity belongs to a standard broadcast :tl:`Channel`.
    * 'M' (77): this entity belongs to a megagroup :tl:`Channel`.
    * 'E' (69): this entity belongs to an "enormous" "gigagroup" :tl:`Channel`.
    """
    __slots__ = ('ty', 'id', 'access_hash')

    # Named constants for the ``ty`` values listed above.
    USER = ord('U')
    BOT = ord('B')
    GROUP = ord('G')
    CHANNEL = ord('C')
    MEGAGROUP = ord('M')
    GIGAGROUP = ord('E')

    def __init__(
            self,
            ty: int,
            id: int,
            access_hash: int
    ):
        self.ty, self.id, self.access_hash = ty, id, access_hash
def canonical_entity_type(ty: int, *, _mapping={
    Entity.USER: Entity.USER,
    Entity.BOT: Entity.USER,
    Entity.GROUP: Entity.GROUP,
    Entity.CHANNEL: Entity.CHANNEL,
    Entity.MEGAGROUP: Entity.CHANNEL,
    Entity.GIGAGROUP: Entity.CHANNEL,
}) -> int:
    """
    Return the canonical version of an entity type.

    Users and bots map to `Entity.USER`, small groups to `Entity.GROUP`, and
    broadcast channels, megagroups and gigagroups to `Entity.CHANNEL`.

    :param ty: one of the `Entity` type constants.
    :param _mapping: private lookup table built once at definition time; it is
        only read and callers are not expected to pass it.
    :raises ValueError: if ``ty`` is not a known entity type.
    """
    try:
        return _mapping[ty]
    except KeyError:
        pass

    # Show the type as a character when possible, since that is how types are
    # documented. chr() itself fails for integers outside the Unicode range
    # (ValueError) or too large for it (OverflowError); in that case the number
    # is shown as is so that the error raised is always the ValueError below.
    shown = ty
    if isinstance(ty, int):
        try:
            shown = chr(ty)
        except (ValueError, OverflowError):
            pass

    raise ValueError(f'entity type {shown!r} is not valid')
def get_entity_type_group(ty: int, *, _mapping={
    Entity.USER: (Entity.USER, Entity.BOT),
    Entity.BOT: (Entity.USER, Entity.BOT),
    Entity.GROUP: (Entity.GROUP,),
    Entity.CHANNEL: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
    Entity.MEGAGROUP: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
    Entity.GIGAGROUP: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
}) -> Tuple[int, ...]:
    """
    Return the group where an entity type belongs to.

    The group is the tuple of all entity types considered equivalent to ``ty``
    (including ``ty`` itself).

    :param ty: one of the `Entity` type constants.
    :param _mapping: private lookup table built once at definition time; it is
        only read and callers are not expected to pass it.
    :raises ValueError: if ``ty`` is not a known entity type.
    """
    try:
        return _mapping[ty]
    except KeyError:
        pass

    # Show the type as a character when possible, since that is how types are
    # documented. chr() itself fails for integers outside the Unicode range
    # (ValueError) or too large for it (OverflowError); in that case the number
    # is shown as is so that the error raised is always the ValueError below.
    shown = ty
    if isinstance(ty, int):
        try:
            shown = chr(ty)
        except (ValueError, OverflowError):
            pass

    raise ValueError(f'entity type {shown!r} is not valid')