Telethon/client/src/telethon/_impl/session/storage/sqlite.py
2023-10-02 23:26:40 +02:00

212 lines
6.9 KiB
Python

import sqlite3
from pathlib import Path
from typing import Optional, Union
from ..session import ChannelState, DataCenter, Session, UpdateState, User
from .storage import Storage
# Suffix appended by `SqliteSession.__init__` to paths that have no suffix.
EXTENSION = ".session"
# Schema version written to the `version` table when the schema is created;
# `SqliteSession.load` migrates databases whose stored version is lower.
CURRENT_VERSION = 10
class SqliteSession(Storage):
    """
    Session storage backed by SQLite.

    SQLite is a reliable way to persist data to disk and offers file locking.

    Paths without extension will have ``'.session'`` appended to them.
    This is by convention, and to make it harder to commit session files to
    a VCS by accident (adding ``*.session`` to ``.gitignore`` will catch them).

    The database connection is opened lazily on first use and closed again
    after every :meth:`save` or :meth:`delete`.
    """

    def __init__(self, file: Union[str, Path]):
        """
        :param file: Path of the database file. ``EXTENSION`` is appended
            when the path has no suffix. No connection is opened here.
        """
        path = Path(file)
        if not path.suffix:
            path = path.with_suffix(EXTENSION)

        self._path = path
        self._conn: Optional[sqlite3.Connection] = None

    async def load(self) -> Optional[Session]:
        """
        Load the session stored in the database.

        The schema is created when the database has no valid version. A
        database stored in format 7 is read, wiped and rewritten in the
        current format before being loaded.

        :return: The stored session.
        :raises ValueError: If the stored version is older than
            ``CURRENT_VERSION`` and is not 7.
        """
        conn = self._current_conn()

        c = conn.cursor()
        # As a context manager the connection commits on success and rolls
        # back on error; it does not close the connection.
        with conn:
            version = self._get_or_init_version(c)
            if version < CURRENT_VERSION:
                if version != 7:
                    raise ValueError(
                        "only migration from sqlite session format 7 supported"
                    )

                # Read the old data first, then replace the whole schema.
                session = self._load_v7(c)
                self._reset(c)
                self._get_or_init_version(c)
                self._save_v10(c, session)

            return self._load_v10(c)

    async def save(self, session: Session) -> None:
        """
        Replace the stored session with ``session`` and close the connection.

        The connection is closed and forgotten even when writing fails, so
        that a failed save does not leave a cached open connection behind.

        :param session: The session to persist.
        """
        conn = self._current_conn()
        try:
            with conn:
                self._save_v10(conn.cursor(), session)
        finally:
            conn.close()
            self._conn = None

    async def delete(self) -> None:
        """
        Close the connection (if open) and remove the database file.

        A file that does not exist is not an error, which makes repeated
        calls (or a call before anything was ever stored) safe.
        """
        if self._conn is not None:
            self._conn.close()
            self._conn = None

        self._path.unlink(missing_ok=True)

    def _current_conn(self) -> sqlite3.Connection:
        """Return the cached connection, opening it first if necessary."""
        if self._conn is None:
            self._conn = sqlite3.connect(self._path)
        return self._conn

    @staticmethod
    def _load_v7(c: sqlite3.Cursor) -> Session:
        """
        Read a session stored in the v7 format (from telethon v1).

        :param c: Cursor on a database holding the v7 tables.
        :return: The session, with ``user`` unset. ``state`` is ``None`` when
            the database has no ``update_state`` row with ``id = 0``.
        """
        c.execute("select dc_id, server_address, port, auth_key from sessions")
        sessions = c.fetchall()
        c.execute("select pts, qts, date, seq from update_state where id = 0")
        state = c.fetchone()
        c.execute("select id, pts from update_state where id != 0")
        channelstate = c.fetchall()

        return Session(
            dcs=[
                DataCenter(id=dc_id, addr=f"{ip}:{port}", auth=auth)
                for (dc_id, ip, port, auth) in sessions
            ],
            user=None,
            # The row with id 0 may be absent; mirror `_load_v10` and leave
            # the state unset instead of failing on `None[0]`.
            state=UpdateState(
                pts=state[0],
                qts=state[1],
                date=state[2],
                seq=state[3],
                channels=[
                    ChannelState(id=channel_id, pts=pts)
                    for channel_id, pts in channelstate
                ],
            )
            if state
            else None,
        )

    @staticmethod
    def _load_v10(c: sqlite3.Cursor) -> Session:
        """
        Read a session stored in the v10 format.

        :param c: Cursor on a database holding the v10 tables.
        :return: The session; ``user`` and ``state`` are ``None`` when their
            tables are empty.
        """
        # `select *` relies on the column order declared in `_create_tables`.
        c.execute("select * from datacenter")
        datacenter = c.fetchall()
        c.execute("select * from user")
        user = c.fetchone()
        c.execute("select * from state")
        state = c.fetchone()
        c.execute("select * from channelstate")
        channelstate = c.fetchall()

        return Session(
            dcs=[
                DataCenter(id=dc_id, addr=addr, auth=auth)
                for (dc_id, addr, auth) in datacenter
            ],
            user=User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3])
            if user
            else None,
            state=UpdateState(
                pts=state[0],
                qts=state[1],
                date=state[2],
                seq=state[3],
                channels=[
                    ChannelState(id=channel_id, pts=pts)
                    for channel_id, pts in channelstate
                ],
            )
            if state
            else None,
        )

    @staticmethod
    def _save_v10(c: sqlite3.Cursor, session: Session) -> None:
        """
        Overwrite the contents of the v10 tables with ``session``.

        :param c: Cursor on a database holding the v10 tables.
        :param session: The session to store. Channel states are only written
            when ``session.state`` is set.
        """
        # Clear everything first so stale rows never outlive the new data.
        c.execute("delete from datacenter")
        c.execute("delete from user")
        c.execute("delete from state")
        c.execute("delete from channelstate")
        c.executemany(
            "insert into datacenter values (?, ?, ?)",
            [(dc.id, dc.addr, dc.auth) for dc in session.dcs],
        )
        if user := session.user:
            c.execute(
                "insert into user values (?, ?, ?, ?)",
                (user.id, user.dc, int(user.bot), user.username),
            )
        if state := session.state:
            c.execute(
                "insert into state values (?, ?, ?, ?)",
                (state.pts, state.qts, state.date, state.seq),
            )
            c.executemany(
                "insert into channelstate values (?, ?)",
                [(channel.id, channel.pts) for channel in state.channels],
            )

    @staticmethod
    def _reset(c: sqlite3.Cursor) -> None:
        """
        Drop every table listed in ``sqlite_master``.

        :param c: Cursor on the database to wipe.
        :raises ValueError: If a table name is longer than 20 characters or
            contains anything other than ASCII letters and underscores.
        """
        safe_chars = "_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

        c.execute("select name from sqlite_master where type='table'")
        for (name,) in c.fetchall():
            # Can't format arguments for table names. Regardless, it shouldn't
            # be an SQL-injection because names come from `sqlite_master`.
            # Just to be on the safe-side, only accept short names made of
            # ASCII letters and underscores, avoiding referencing globals
            # which could've been monkey-patched.
            if name.__len__() > 20:
                raise ValueError(f"potentially unsafe table name: {name}")
            for char in name:
                if char not in safe_chars:
                    raise ValueError(f"potentially unsafe table name: {name}")

            c.execute(f"drop table {name}")

    @staticmethod
    def _get_or_init_version(c: sqlite3.Cursor) -> int:
        """
        Return the stored schema version, initializing the database if needed.

        When there is no ``version`` table, or it holds no integer, every
        existing table is dropped and a fresh current-format schema is created.

        :param c: Cursor on the database to inspect.
        :return: The stored version, or ``CURRENT_VERSION`` after initializing.
        """
        c.execute(
            "select name from sqlite_master where type='table' and name='version'"
        )
        if c.fetchone():
            c.execute("select version from version")
            tup = c.fetchone()
            if tup and isinstance(tup[0], int):
                return tup[0]

        SqliteSession._reset(c)
        SqliteSession._create_tables(c)
        c.execute("insert into version values (?)", (CURRENT_VERSION,))
        return CURRENT_VERSION

    @staticmethod
    def _create_tables(c: sqlite3.Cursor) -> None:
        """
        Create the empty v10 tables.

        :param c: Cursor on a database that does not have these tables yet.
        """
        # NOTE: per the sqlite3 docs, `executescript` first commits any
        # pending transaction under the default transaction control.
        c.executescript(
            """
            create table version (
                version integer primary key
            );
            create table datacenter(
                id integer primary key,
                addr text not null,
                auth blob
            );
            create table user(
                id integer primary key,
                dc integer not null,
                bot integer not null,
                username text
            );
            create table state(
                pts integer not null,
                qts integer not null,
                date integer not null,
                seq integer not null
            );
            create table channelstate(
                id integer primary key,
                pts integer not null
            );
            """
        )