import json
import os
import platform
import sqlite3
import struct
import time
from base64 import b64decode
from os.path import isfile as file_exists
from threading import Lock

from .entity_database import EntityDatabase
from .. import helpers

EXTENSION = '.session'
CURRENT_VERSION = 1  # database version


class Session:
    """This session contains the required information to login into your
       Telegram account. NEVER give the saved session file to anyone, since
       they would gain instant access to all your messages and contacts.

       If you think the session has been compromised, close all the sessions
       through an official Telegram client to revoke the authorization.
    """
    def __init__(self, session_user_id):
        """session_user_id should either be a string or another Session.
           Note that if another session is given, only parameters like
           those required to init a connection will be copied.

           When a string is given, it names the SQLite file backing this
           session (EXTENSION is appended if missing). A falsy value keeps
           the session in memory only. A legacy JSON session stored under
           the same file name is migrated into the database on first use.
        """
        # These values will NOT be saved
        self.filename = ':memory:'

        if isinstance(session_user_id, Session):
            self.session_user_id = None

            # For connection purposes
            session = session_user_id
            self.device_model = session.device_model
            self.system_version = session.system_version
            self.app_version = session.app_version
            self.lang_code = session.lang_code
            self.system_lang_code = session.system_lang_code
            self.lang_pack = session.lang_pack
            self.report_errors = session.report_errors
            self.save_entities = session.save_entities
            self.flood_sleep_threshold = session.flood_sleep_threshold

        else:  # str / None
            if session_user_id:
                self.filename = session_user_id
                if not self.filename.endswith(EXTENSION):
                    self.filename += EXTENSION

            system = platform.uname()
            self.device_model = system.system if system.system else 'Unknown'
            self.system_version = system.release if system.release else '1.0'
            self.app_version = '1.0'  # '0' will provoke error
            self.lang_code = 'en'
            self.system_lang_code = self.lang_code
            self.lang_pack = ''
            self.report_errors = True
            self.save_entities = True
            self.flood_sleep_threshold = 60

        # These values will be saved
        self._server_address = None
        self._port = None
        self._auth_key = None
        self._layer = 0
        self._salt = 0  # Signed long
        self.entities = EntityDatabase()  # Known and cached entities

        # Cross-thread safety
        self._seq_no_lock = Lock()
        self._msg_id_lock = Lock()
        self._db_lock = Lock()

        # Migrating from .json -> SQL (removes the JSON file on success)
        migrated = self._check_migrate_json()

        self._conn = sqlite3.connect(self.filename, check_same_thread=False)
        c = self._conn.cursor()
        c.execute("select name from sqlite_master "
                  "where type='table' and name='version'")
        if c.fetchone():
            # Tables already exist, check for the version
            c.execute("select version from version")
            version = c.fetchone()[0]
            if version != CURRENT_VERSION:
                self._upgrade_database(old=version)
                self.save()

            # These values will be saved
            c.execute('select * from sessions')
            row = c.fetchone()
            # The table stays empty until a saved property is first set,
            # so a missing row just means "nothing stored yet".
            if row:
                self._server_address, self._port, key, \
                    self._layer, self._salt = row

                # _update_session_table stores b'' when there is no key;
                # that must load back as "no key", not as an empty AuthKey.
                if key:
                    from ..crypto import AuthKey
                    self._auth_key = AuthKey(data=key)

            c.close()
        else:
            # Tables don't exist, create new ones
            c.execute("create table version (version integer)")
            c.execute(
                """create table sessions (
                    server_address text,
                    port integer,
                    auth_key blob,
                    layer integer,
                    salt integer
                )"""
            )
            c.execute(
                """create table entities (
                    id integer,
                    hash integer,
                    username text,
                    phone integer,
                    name text
                )"""
            )
            c.execute("insert into version values (?)", (CURRENT_VERSION,))
            c.close()

            # The JSON file was already deleted during migration, so the
            # values it held must be written now or they would be lost.
            if migrated:
                self._update_session_table()

            self.save()

        self.id = helpers.generate_random_long(signed=True)
        self._sequence = 0
        self.time_offset = 0
        self._last_msg_id = 0  # Long

    def _check_migrate_json(self):
        """Loads a legacy JSON session stored at self.filename, if any.

           On success the loaded values are kept in memory, the JSON file
           is deleted so the SQLite database can take its place, and True
           is returned. If the file does not exist or cannot be decoded as
           JSON (e.g. it already is a SQLite database), False is returned
           and nothing is changed.
        """
        if not file_exists(self.filename):
            return False

        try:
            with open(self.filename, encoding='utf-8') as f:
                data = json.load(f)
        except (UnicodeDecodeError, json.decoder.JSONDecodeError):
            # Not a JSON session (most likely an existing database)
            return False

        self._port = data.get('port', self._port)
        self._salt = data.get('salt', self._salt)
        # Keep while migrating from unsigned to signed salt
        if self._salt > 0:
            self._salt = struct.unpack(
                'q', struct.pack('Q', self._salt))[0]

        self._layer = data.get('layer', self._layer)
        self._server_address = \
            data.get('server_address', self._server_address)

        from ..crypto import AuthKey
        if data.get('auth_key_data', None) is not None:
            key = b64decode(data['auth_key_data'])
            self._auth_key = AuthKey(data=key)

        self.entities = EntityDatabase(data.get('entities', []))
        self.delete()  # Delete JSON file to create database
        return True

    def _upgrade_database(self, old):
        """Hook to migrate a database created with version `old` to
           CURRENT_VERSION. Currently a no-op since only version 1 exists.
        """
        pass

    # Data from sessions should be kept as properties
    # not to fetch the database every time we need it
    @property
    def server_address(self):
        return self._server_address

    @server_address.setter
    def server_address(self, value):
        self._server_address = value
        self._update_session_table()

    @property
    def port(self):
        return self._port

    @port.setter
    def port(self, value):
        self._port = value
        self._update_session_table()

    @property
    def auth_key(self):
        return self._auth_key

    @auth_key.setter
    def auth_key(self, value):
        self._auth_key = value
        self._update_session_table()

    @property
    def layer(self):
        return self._layer

    @layer.setter
    def layer(self, value):
        self._layer = value
        self._update_session_table()

    @property
    def salt(self):
        return self._salt

    @salt.setter
    def salt(self, value):
        self._salt = value
        self._update_session_table()

    def _update_session_table(self):
        """Replaces the single row of the sessions table with the current
           in-memory values. Changes are not committed; call save() for that.
        """
        with self._db_lock:
            c = self._conn.cursor()
            c.execute('delete from sessions')
            c.execute('insert into sessions values (?,?,?,?,?)', (
                self._server_address,
                self._port,
                # An empty blob marks "no auth key" (see __init__)
                self._auth_key.key if self._auth_key else b'',
                self._layer,
                self._salt
            ))
            c.close()

    def save(self):
        """Saves the current session object as session_user_id.session"""
        with self._db_lock:
            self._conn.commit()

    def delete(self):
        """Deletes the current session file.

           Returns True if the file was removed (or the session lives in
           memory only), False if removing it failed.
        """
        if self.filename == ':memory:':
            return True
        try:
            os.remove(self.filename)
            return True
        except OSError:
            return False

    @staticmethod
    def list_sessions():
        """Lists all the sessions of the users who have ever connected
           using this client and never logged out
        """
        return [os.path.splitext(os.path.basename(f))[0]
                for f in os.listdir('.') if f.endswith(EXTENSION)]

    def generate_sequence(self, content_related):
        """Thread safe method to generates the next sequence number,
           based on whether the message is content-related or not.

           Note that if content_related=True, the sequence number will be
           increased by one too
        """
        with self._seq_no_lock:
            if content_related:
                result = self._sequence * 2 + 1
                self._sequence += 1
                return result
            else:
                return self._sequence * 2

    def get_new_msg_id(self):
        """Generates a new unique message ID based on the current time
           since epoch, corrected by self.time_offset.

           The upper 32 bits hold the whole seconds and the lower bits the
           sub-second part in nanoseconds shifted left by 2, so the ID is
           divisible by 4. Each returned ID is greater than the previous one.
        """
        # Refer to mtproto_plain_sender.py for the original method
        # Apply the offset learned through update_time_offset(); without it
        # correcting the clock would have no effect on the generated IDs.
        now = time.time() + self.time_offset
        nanoseconds = int((now - int(now)) * 1e+9)
        # "message identifiers are divisible by 4"
        new_msg_id = (int(now) << 32) | (nanoseconds << 2)

        with self._msg_id_lock:
            if self._last_msg_id >= new_msg_id:
                new_msg_id = self._last_msg_id + 4

            self._last_msg_id = new_msg_id

        return new_msg_id

    def update_time_offset(self, correct_msg_id):
        """Updates the time offset based on a known correct message ID"""
        now = int(time.time())
        correct = correct_msg_id >> 32
        self.time_offset = correct - now

    def process_entities(self, tlobject):
        """Feeds tlobject to the entity database and saves if any new
           entities were added. Failures are ignored (best effort), but
           interpreter-exit exceptions are no longer swallowed.
        """
        try:
            if self.entities.process(tlobject):
                self.save()  # Save if any new entities got added
        except Exception:
            pass