# Mirror of https://github.com/LonamiWebs/Telethon.git
# (synced 2025-06-15 17:46:51 +00:00; 212 lines, 6.9 KiB, Python).
import sqlite3
|
|
from pathlib import Path
|
|
from typing import Optional, Union
|
|
|
|
from ..session import ChannelState, DataCenter, Session, UpdateState, User
|
|
from .storage import Storage
|
|
|
|
# Suffix appended to session paths that are given without any extension.
EXTENSION = ".session"
# Schema version written to the `version` table when the tables are created;
# stored sessions reporting a lower version are migrated (or rejected) on load.
CURRENT_VERSION = 10
|
|
|
|
|
|
class SqliteSession(Storage):
    """
    Session storage backed by SQLite.

    SQLite is a reliable way to persist data to disk and offers file locking.

    Paths without extension will have ``'.session'`` appended to them.
    This is by convention, and to make it harder to commit session files to
    a VCS by accident (adding ``*.session`` to ``.gitignore`` will catch them).
    """

    def __init__(self, file: Union[str, Path]):
        """
        Remember where the session file lives; no connection is opened yet.

        :param file: Path to the session file. If it has no suffix,
            ``EXTENSION`` is appended.
        """
        path = Path(file)
        if not path.suffix:
            path = path.with_suffix(EXTENSION)

        self._path = path
        # Opened lazily by `_current_conn` and dropped again by `save`/`delete`.
        self._conn: Optional[sqlite3.Connection] = None

    async def load(self) -> Optional[Session]:
        """
        Load the stored session, migrating it from format 7 if needed.

        The tables are created first when the database has no usable
        ``version`` table.

        :return: The session rebuilt from the current-format tables.
        :raises ValueError: If the stored version is older than
            ``CURRENT_VERSION`` and is not 7.
        """
        conn = self._current_conn()

        c = conn.cursor()
        with conn:
            version = self._get_or_init_version(c)
            if version < CURRENT_VERSION:
                if version == 7:
                    session = self._load_v7(c)
                else:
                    raise ValueError(
                        "only migration from sqlite session format 7 supported"
                    )

                # Read the old data first, then drop every table, recreate the
                # current schema and write the migrated session into it.
                self._reset(c)
                self._get_or_init_version(c)
                self._save_v10(c, session)

            return self._load_v10(c)

    async def save(self, session: Session) -> None:
        """
        Replace the stored session with ``session`` and close the connection.

        The write runs inside the connection's context manager. Afterwards the
        connection is closed and forgotten, so the next operation reconnects.

        :param session: The session data to persist.
        """
        conn = self._current_conn()
        with conn:
            self._save_v10(conn.cursor(), session)
        conn.close()
        self._conn = None

    async def delete(self) -> None:
        """
        Close the connection (if one is open) and remove the session file.

        ``Path.unlink`` is called without ``missing_ok``, so a missing file
        propagates as an error.
        """
        if self._conn:
            self._conn.close()
            self._conn = None
        self._path.unlink()

    def _current_conn(self) -> sqlite3.Connection:
        """Return the cached connection, opening one to the file if needed."""
        if self._conn is None:
            self._conn = sqlite3.connect(self._path)

        return self._conn

    @staticmethod
    def _load_v7(c: sqlite3.Cursor) -> Session:
        """
        Read a session stored in the v7 layout (telethon v1).

        :param c: Cursor on a database holding the v7 ``sessions`` and
            ``update_state`` tables.
        :return: A session with no user. Its state is ``None`` when the
            ``update_state`` table has no row with ``id = 0``.
        """
        # Session v7 format from telethon v1
        c.execute("select dc_id, server_address, port, auth_key from sessions")
        sessions = c.fetchall()
        c.execute("select pts, qts, date, seq from update_state where id = 0")
        state = c.fetchone()
        c.execute("select id, pts from update_state where id != 0")
        channelstate = c.fetchall()

        return Session(
            dcs=[
                DataCenter(id=dc_id, addr=f"{ip}:{port}", auth=auth)
                for (dc_id, ip, port, auth) in sessions
            ],
            user=None,
            # `fetchone` yields None when there is no row with id 0; mirror
            # `_load_v10` and report "no state" instead of failing on `None[0]`.
            state=UpdateState(
                pts=state[0],
                qts=state[1],
                date=state[2],
                seq=state[3],
                channels=[
                    ChannelState(id=channel_id, pts=pts)
                    for channel_id, pts in channelstate
                ],
            )
            if state
            else None,
        )

    @staticmethod
    def _load_v10(c: sqlite3.Cursor) -> Session:
        """
        Read a session stored in the current (v10) layout.

        :param c: Cursor on a database holding the tables made by
            `_create_tables`.
        :return: A session whose ``user`` and ``state`` are ``None`` when the
            respective table is empty.
        """
        c.execute("select * from datacenter")
        datacenter = c.fetchall()
        c.execute("select * from user")
        user = c.fetchone()
        c.execute("select * from state")
        state = c.fetchone()
        c.execute("select * from channelstate")
        channelstate = c.fetchall()

        return Session(
            dcs=[
                DataCenter(id=dc_id, addr=addr, auth=auth)
                for (dc_id, addr, auth) in datacenter
            ],
            user=User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3])
            if user
            else None,
            state=UpdateState(
                pts=state[0],
                qts=state[1],
                date=state[2],
                seq=state[3],
                channels=[
                    ChannelState(id=channel_id, pts=pts)
                    for channel_id, pts in channelstate
                ],
            )
            if state
            else None,
        )

    @staticmethod
    def _save_v10(c: sqlite3.Cursor, session: Session) -> None:
        """
        Overwrite every v10 data table with the contents of ``session``.

        :param c: Cursor on a database holding the v10 tables.
        :param session: The session to write. The user and state rows are only
            inserted when present, and channel states are only written
            together with a state.
        """
        # Clear everything first so stale rows never outlive the new session.
        c.execute("delete from datacenter")
        c.execute("delete from user")
        c.execute("delete from state")
        c.execute("delete from channelstate")
        c.executemany(
            "insert into datacenter values (?, ?, ?)",
            [(dc.id, dc.addr, dc.auth) for dc in session.dcs],
        )
        if user := session.user:
            c.execute(
                "insert into user values (?, ?, ?, ?)",
                (user.id, user.dc, int(user.bot), user.username),
            )
        if state := session.state:
            c.execute(
                "insert into state values (?, ?, ?, ?)",
                (state.pts, state.qts, state.date, state.seq),
            )
            c.executemany(
                "insert into channelstate values (?, ?)",
                [(channel.id, channel.pts) for channel in state.channels],
            )

    @staticmethod
    def _reset(c: sqlite3.Cursor) -> None:
        """
        Drop every table listed in ``sqlite_master``.

        :param c: Cursor on the database to wipe.
        :raises ValueError: If a table name is longer than 20 characters or
            contains anything other than ASCII letters and underscores.
        """
        safe_chars = "_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

        c.execute("select name from sqlite_master where type='table'")
        for (name,) in c.fetchall():
            # Can't format arguments for table names. Regardless, it shouldn't
            # be an SQL-injection because names come from `sqlite_master`.
            # Just to be on the safe-side, check against a conservative subset
            # of r'\w+' (letters and underscore only) nevertheless,
            # avoiding referencing globals which could've been monkey-patched.
            for char in name:
                if char not in safe_chars or name.__len__() > 20:
                    raise ValueError(f"potentially unsafe table name: {name}")

            c.execute(f"drop table {name}")

    @staticmethod
    def _get_or_init_version(c: sqlite3.Cursor) -> int:
        """
        Return the stored schema version, initializing the schema if absent.

        If a ``version`` table exists and holds an integer, that value is
        returned. If the table exists without a usable value, every table is
        dropped. In both remaining cases the current tables are created and
        ``CURRENT_VERSION`` is stored.

        :param c: Cursor on the session database.
        :return: The stored version, or ``CURRENT_VERSION`` after initializing.
        """
        c.execute(
            "select name from sqlite_master where type='table' and name='version'"
        )
        if c.fetchone():
            c.execute("select version from version")
            tup = c.fetchone()
            if tup and isinstance(tup[0], int):
                return tup[0]
            # A version table without a usable value: start from scratch.
            SqliteSession._reset(c)

        SqliteSession._create_tables(c)
        c.execute("insert into version values (?)", (CURRENT_VERSION,))
        return CURRENT_VERSION

    @staticmethod
    def _create_tables(c: sqlite3.Cursor) -> None:
        """
        Create the tables of the current schema.

        :param c: Cursor on a database that does not yet have these tables.
        """
        c.executescript(
            """
            create table version (
                version integer primary key
            );
            create table datacenter(
                id integer primary key,
                addr text not null,
                auth blob
            );
            create table user(
                id integer primary key,
                dc integer not null,
                bot integer not null,
                username text
            );
            create table state(
                pts integer not null,
                qts integer not null,
                date integer not null,
                seq integer not null
            );
            create table channelstate(
                id integer primary key,
                pts integer not null
            );
            """
        )