Split Session into three parts and make a module for sessions

This commit is contained in:
Tulir Asokan
2018-03-01 23:34:32 +02:00
parent f09ab6c6b6
commit c5e6f7e265
6 changed files with 491 additions and 193 deletions

View File

@@ -0,0 +1,3 @@
from .abstract import Session
from .memory import MemorySession
from .sqlite import SQLiteSession

View File

@@ -0,0 +1,136 @@
from abc import ABC, abstractmethod
class Session(ABC):
@abstractmethod
def clone(self):
raise NotImplementedError
@abstractmethod
def set_dc(self, dc_id, server_address, port):
raise NotImplementedError
@property
@abstractmethod
def server_address(self):
raise NotImplementedError
@property
@abstractmethod
def port(self):
raise NotImplementedError
@property
@abstractmethod
def auth_key(self):
raise NotImplementedError
@auth_key.setter
@abstractmethod
def auth_key(self, value):
raise NotImplementedError
@property
@abstractmethod
def time_offset(self):
raise NotImplementedError
@time_offset.setter
@abstractmethod
def time_offset(self, value):
raise NotImplementedError
@property
@abstractmethod
def salt(self):
raise NotImplementedError
@salt.setter
@abstractmethod
def salt(self, value):
raise NotImplementedError
@property
@abstractmethod
def device_model(self):
raise NotImplementedError
@property
@abstractmethod
def system_version(self):
raise NotImplementedError
@property
@abstractmethod
def app_version(self):
raise NotImplementedError
@property
@abstractmethod
def lang_code(self):
raise NotImplementedError
@property
@abstractmethod
def system_lang_code(self):
raise NotImplementedError
@property
@abstractmethod
def report_errors(self):
raise NotImplementedError
@property
@abstractmethod
def sequence(self):
raise NotImplementedError
@property
@abstractmethod
def flood_sleep_threshold(self):
raise NotImplementedError
@abstractmethod
def close(self):
raise NotImplementedError
@abstractmethod
def save(self):
raise NotImplementedError
@abstractmethod
def delete(self):
raise NotImplementedError
@classmethod
@abstractmethod
def list_sessions(cls):
raise NotImplementedError
@abstractmethod
def get_new_msg_id(self):
raise NotImplementedError
@abstractmethod
def update_time_offset(self, correct_msg_id):
raise NotImplementedError
@abstractmethod
def generate_sequence(self, content_related):
raise NotImplementedError
@abstractmethod
def process_entities(self, tlo):
raise NotImplementedError
@abstractmethod
def get_input_entity(self, key):
raise NotImplementedError
@abstractmethod
def cache_file(self, md5_digest, file_size, instance):
raise NotImplementedError
@abstractmethod
def get_file(self, md5_digest, file_size, cls):
raise NotImplementedError

297
telethon/sessions/memory.py Normal file
View File

@@ -0,0 +1,297 @@
from enum import Enum
import time
import platform
from .. import utils
from .abstract import Session
from ..tl import TLObject
from ..tl.types import (
PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel,
InputPhoto, InputDocument
)
class _SentFileType(Enum):
DOCUMENT = 0
PHOTO = 1
@staticmethod
def from_type(cls):
if cls == InputDocument:
return _SentFileType.DOCUMENT
elif cls == InputPhoto:
return _SentFileType.PHOTO
else:
raise ValueError('The cls must be either InputDocument/InputPhoto')
class MemorySession(Session):
def __init__(self):
self._dc_id = None
self._server_address = None
self._port = None
self._salt = None
self._auth_key = None
self._sequence = 0
self._last_msg_id = 0
self._time_offset = 0
self._flood_sleep_threshold = 60
system = platform.uname()
self._device_model = system.system or 'Unknown'
self._system_version = system.release or '1.0'
self._app_version = '1.0'
self._lang_code = 'en'
self._system_lang_code = self.lang_code
self._report_errors = True
self._flood_sleep_threshold = 60
self._files = {}
self._entities = set()
def clone(self):
cloned = MemorySession()
cloned._device_model = self.device_model
cloned._system_version = self.system_version
cloned._app_version = self.app_version
cloned._lang_code = self.lang_code
cloned._system_lang_code = self.system_lang_code
cloned._report_errors = self.report_errors
cloned._flood_sleep_threshold = self.flood_sleep_threshold
def set_dc(self, dc_id, server_address, port):
self._dc_id = dc_id
self._server_address = server_address
self._port = port
@property
def server_address(self):
return self._server_address
@property
def port(self):
return self._port
@property
def auth_key(self):
return self._auth_key
@auth_key.setter
def auth_key(self, value):
self._auth_key = value
@property
def time_offset(self):
return self._time_offset
@time_offset.setter
def time_offset(self, value):
self._time_offset = value
@property
def salt(self):
return self._salt
@salt.setter
def salt(self, value):
self._salt = value
@property
def device_model(self):
return self._device_model
@property
def system_version(self):
return self._system_version
@property
def app_version(self):
return self._app_version
@property
def lang_code(self):
return self._lang_code
@property
def system_lang_code(self):
return self._system_lang_code
@property
def report_errors(self):
return self._report_errors
@property
def sequence(self):
return self._sequence
@property
def flood_sleep_threshold(self):
return self._flood_sleep_threshold
def close(self):
pass
def save(self):
pass
def delete(self):
pass
@classmethod
def list_sessions(cls):
raise NotImplementedError
def get_new_msg_id(self):
"""Generates a new unique message ID based on the current
time (in ms) since epoch"""
now = time.time() + self._time_offset
nanoseconds = int((now - int(now)) * 1e+9)
new_msg_id = (int(now) << 32) | (nanoseconds << 2)
if self._last_msg_id >= new_msg_id:
new_msg_id = self._last_msg_id + 4
self._last_msg_id = new_msg_id
return new_msg_id
def update_time_offset(self, correct_msg_id):
now = int(time.time())
correct = correct_msg_id >> 32
self._time_offset = correct - now
self._last_msg_id = 0
def generate_sequence(self, content_related):
if content_related:
result = self._sequence * 2 + 1
self._sequence += 1
return result
else:
return self._sequence * 2
@staticmethod
def _entities_to_rows(tlo):
if not isinstance(tlo, TLObject) and utils.is_list_like(tlo):
# This may be a list of users already for instance
entities = tlo
else:
entities = []
if hasattr(tlo, 'chats') and utils.is_list_like(tlo.chats):
entities.extend(tlo.chats)
if hasattr(tlo, 'users') and utils.is_list_like(tlo.users):
entities.extend(tlo.users)
if not entities:
return
rows = [] # Rows to add (id, hash, username, phone, name)
for e in entities:
if not isinstance(e, TLObject):
continue
try:
p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p)
except ValueError:
continue
if isinstance(p, (InputPeerUser, InputPeerChannel)):
if not p.access_hash:
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
# Note that this checks for zero or None, see #392.
continue
else:
p_hash = p.access_hash
elif isinstance(p, InputPeerChat):
p_hash = 0
else:
continue
username = getattr(e, 'username', None) or None
if username is not None:
username = username.lower()
phone = getattr(e, 'phone', None)
name = utils.get_display_name(e) or None
rows.append((marked_id, p_hash, username, phone, name))
return rows
def process_entities(self, tlo):
self._entities += set(self._entities_to_rows(tlo))
def get_entity_rows_by_phone(self, phone):
rows = [(id, hash) for id, hash, _, found_phone, _
in self._entities if found_phone == phone]
return rows[0] if rows else None
def get_entity_rows_by_username(self, username):
rows = [(id, hash) for id, hash, found_username, _, _
in self._entities if found_username == username]
return rows[0] if rows else None
def get_entity_rows_by_name(self, name):
rows = [(id, hash) for id, hash, _, _, found_name
in self._entities if found_name == name]
return rows[0] if rows else None
def get_entity_rows_by_id(self, id):
rows = [(id, hash) for found_id, hash, _, _, _
in self._entities if found_id == id]
return rows[0] if rows else None
def get_input_entity(self, key):
try:
if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
# hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
# We already have an Input version, so nothing else required
return key
# Try to early return if this key can be casted as input peer
return utils.get_input_peer(key)
except (AttributeError, TypeError):
# Not a TLObject or can't be cast into InputPeer
if isinstance(key, TLObject):
key = utils.get_peer_id(key)
result = None
if isinstance(key, str):
phone = utils.parse_phone(key)
if phone:
result = self.get_entity_rows_by_phone(phone)
else:
username, _ = utils.parse_username(key)
if username:
result = self.get_entity_rows_by_username(username)
if isinstance(key, int):
result = self.get_entity_rows_by_id(key)
if not result and isinstance(key, str):
result = self.get_entity_rows_by_name(key)
if result:
i, h = result # unpack resulting tuple
i, k = utils.resolve_id(i) # removes the mark and returns kind
if k == PeerUser:
return InputPeerUser(i, h)
elif k == PeerChat:
return InputPeerChat(i)
elif k == PeerChannel:
return InputPeerChannel(i, h)
else:
raise ValueError('Could not find input entity with key ', key)
def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
key = (md5_digest, file_size, _SentFileType.from_type(instance))
value = (instance.id, instance.access_hash)
self._files[key] = value
def get_file(self, md5_digest, file_size, cls):
key = (md5_digest, file_size, _SentFileType.from_type(cls))
try:
return self._files[key]
except KeyError:
return None

353
telethon/sessions/sqlite.py Normal file
View File

@@ -0,0 +1,353 @@
import json
import os
import platform
import sqlite3
import struct
import time
from base64 import b64decode
from os.path import isfile as file_exists
from threading import Lock, RLock
from .. import utils
from .abstract import Session
from .memory import MemorySession, _SentFileType
from ..crypto import AuthKey
from ..tl import TLObject
from ..tl.types import (
PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel,
InputPhoto, InputDocument
)
EXTENSION = '.session'
CURRENT_VERSION = 3 # database version
class SQLiteSession(MemorySession):
"""This session contains the required information to login into your
Telegram account. NEVER give the saved JSON file to anyone, since
they would gain instant access to all your messages and contacts.
If you think the session has been compromised, close all the sessions
through an official Telegram client to revoke the authorization.
"""
def __init__(self, session_id):
super().__init__()
"""session_user_id should either be a string or another Session.
Note that if another session is given, only parameters like
those required to init a connection will be copied.
"""
# These values will NOT be saved
self.filename = ':memory:'
# For connection purposes
if isinstance(session_id, Session):
self._device_model = session_id.device_model
self._system_version = session_id.system_version
self._app_version = session_id.app_version
self._lang_code = session_id.lang_code
self._system_lang_code = session_id.system_lang_code
self._report_errors = session_id.report_errors
self._flood_sleep_threshold = session_id.flood_sleep_threshold
if isinstance(session_id, SQLiteSession):
self.save_entities = session_id.save_entities
else: # str / None
if session_id:
self.filename = session_id
if not self.filename.endswith(EXTENSION):
self.filename += EXTENSION
system = platform.uname()
self._device_model = system.system or 'Unknown'
self._system_version = system.release or '1.0'
self._app_version = '1.0' # '0' will provoke error
self._lang_code = 'en'
self._system_lang_code = self.lang_code
self._report_errors = True
self.save_entities = True
self._flood_sleep_threshold = 60
self.id = struct.unpack('q', os.urandom(8))[0]
self._sequence = 0
self.time_offset = 0
self._last_msg_id = 0 # Long
self.salt = 0 # Long
# Cross-thread safety
self._seq_no_lock = Lock()
self._msg_id_lock = Lock()
self._db_lock = RLock()
# These values will be saved
self._dc_id = 0
self._server_address = None
self._port = None
self._auth_key = None
# Migrating from .json -> SQL
entities = self._check_migrate_json()
self._conn = None
c = self._cursor()
c.execute("select name from sqlite_master "
"where type='table' and name='version'")
if c.fetchone():
# Tables already exist, check for the version
c.execute("select version from version")
version = c.fetchone()[0]
if version != CURRENT_VERSION:
self._upgrade_database(old=version)
c.execute("delete from version")
c.execute("insert into version values (?)", (CURRENT_VERSION,))
self.save()
# These values will be saved
c.execute('select * from sessions')
tuple_ = c.fetchone()
if tuple_:
self._dc_id, self._server_address, self._port, key, = tuple_
self._auth_key = AuthKey(data=key)
c.close()
else:
# Tables don't exist, create new ones
self._create_table(
c,
"version (version integer primary key)"
,
"""sessions (
dc_id integer primary key,
server_address text,
port integer,
auth_key blob
)"""
,
"""entities (
id integer primary key,
hash integer not null,
username text,
phone integer,
name text
)"""
,
"""sent_files (
md5_digest blob,
file_size integer,
type integer,
id integer,
hash integer,
primary key(md5_digest, file_size, type)
)"""
)
c.execute("insert into version values (?)", (CURRENT_VERSION,))
# Migrating from JSON -> new table and may have entities
if entities:
c.executemany(
'insert or replace into entities values (?,?,?,?,?)',
entities
)
self._update_session_table()
c.close()
self.save()
def clone(self):
return SQLiteSession(self)
def _check_migrate_json(self):
if file_exists(self.filename):
try:
with open(self.filename, encoding='utf-8') as f:
data = json.load(f)
self.delete() # Delete JSON file to create database
self._port = data.get('port', self._port)
self._server_address = \
data.get('server_address', self._server_address)
if data.get('auth_key_data', None) is not None:
key = b64decode(data['auth_key_data'])
self._auth_key = AuthKey(data=key)
rows = []
for p_id, p_hash in data.get('entities', []):
if p_hash is not None:
rows.append((p_id, p_hash, None, None, None))
return rows
except UnicodeDecodeError:
return [] # No entities
def _upgrade_database(self, old):
c = self._cursor()
# old == 1 doesn't have the old sent_files so no need to drop
if old == 2:
# Old cache from old sent_files lasts then a day anyway, drop
c.execute('drop table sent_files')
self._create_table(c, """sent_files (
md5_digest blob,
file_size integer,
type integer,
id integer,
hash integer,
primary key(md5_digest, file_size, type)
)""")
c.close()
@staticmethod
def _create_table(c, *definitions):
"""
Creates a table given its definition 'name (columns).
If the sqlite version is >= 3.8.2, it will use "without rowid".
See http://www.sqlite.org/releaselog/3_8_2.html.
"""
required = (3, 8, 2)
sqlite_v = tuple(int(x) for x in sqlite3.sqlite_version.split('.'))
extra = ' without rowid' if sqlite_v >= required else ''
for definition in definitions:
c.execute('create table {}{}'.format(definition, extra))
# Data from sessions should be kept as properties
# not to fetch the database every time we need it
def set_dc(self, dc_id, server_address, port):
self._dc_id = dc_id
self._server_address = server_address
self._port = port
self._update_session_table()
# Fetch the auth_key corresponding to this data center
c = self._cursor()
c.execute('select auth_key from sessions')
tuple_ = c.fetchone()
if tuple_ and tuple_[0]:
self._auth_key = AuthKey(data=tuple_[0])
else:
self._auth_key = None
c.close()
@Session.auth_key.setter
def auth_key(self, value):
self._auth_key = value
self._update_session_table()
def _update_session_table(self):
with self._db_lock:
c = self._cursor()
# While we can save multiple rows into the sessions table
# currently we only want to keep ONE as the tables don't
# tell us which auth_key's are usable and will work. Needs
# some more work before being able to save auth_key's for
# multiple DCs. Probably done differently.
c.execute('delete from sessions')
c.execute('insert or replace into sessions values (?,?,?,?)', (
self._dc_id,
self._server_address,
self._port,
self._auth_key.key if self._auth_key else b''
))
c.close()
def save(self):
"""Saves the current session object as session_user_id.session"""
with self._db_lock:
self._conn.commit()
def _cursor(self):
"""Asserts that the connection is open and returns a cursor"""
with self._db_lock:
if self._conn is None:
self._conn = sqlite3.connect(self.filename,
check_same_thread=False)
return self._conn.cursor()
def close(self):
"""Closes the connection unless we're working in-memory"""
if self.filename != ':memory:':
with self._db_lock:
if self._conn is not None:
self._conn.close()
self._conn = None
def delete(self):
"""Deletes the current session file"""
if self.filename == ':memory:':
return True
try:
os.remove(self.filename)
return True
except OSError:
return False
@classmethod
def list_sessions(cls):
"""Lists all the sessions of the users who have ever connected
using this client and never logged out
"""
return [os.path.splitext(os.path.basename(f))[0]
for f in os.listdir('.') if f.endswith(EXTENSION)]
# Entity processing
def process_entities(self, tlo):
"""Processes all the found entities on the given TLObject,
unless .enabled is False.
Returns True if new input entities were added.
"""
if not self.save_entities:
return
rows = self._entities_to_rows(tlo)
if not rows:
return
with self._db_lock:
self._cursor().executemany(
'insert or replace into entities values (?,?,?,?,?)', rows
)
self.save()
def _fetchone_entity(self, query, args):
c = self._cursor()
c.execute(query, args)
return c.fetchone()
def get_entity_rows_by_phone(self, phone):
return self._fetchone_entity(
'select id, hash from entities where phone=?', (phone,))
def get_entity_rows_by_username(self, username):
self._fetchone_entity('select id, hash from entities where username=?',
(username,))
def get_entity_rows_by_name(self, name):
self._fetchone_entity('select id, hash from entities where name=?',
(name,))
def get_entity_rows_by_id(self, id):
self._fetchone_entity('select id, hash from entities where id=?',
(id,))
# File processing
def get_file(self, md5_digest, file_size, cls):
tuple_ = self._cursor().execute(
'select id, hash from sent_files '
'where md5_digest = ? and file_size = ? and type = ?',
(md5_digest, file_size, _SentFileType.from_type(cls).value)
).fetchone()
if tuple_:
# Both allowed classes have (id, access_hash) as parameters
return cls(tuple_[0], tuple_[1])
def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
with self._db_lock:
self._cursor().execute(
'insert or replace into sent_files values (?,?,?,?,?)', (
md5_digest, file_size,
_SentFileType.from_type(type(instance)).value,
instance.id, instance.access_hash
))
self.save()