diff --git a/telethon/network/mtproto_sender.py b/telethon/network/mtproto_sender.py index 43b5e803..cbcdc76d 100644 --- a/telethon/network/mtproto_sender.py +++ b/telethon/network/mtproto_sender.py @@ -402,13 +402,13 @@ class MtProtoSender: elif bad_msg.error_code == 32: # msg_seqno too low, so just pump it up by some "large" amount # TODO A better fix would be to start with a new fresh session ID - self.session._sequence += 64 + self.session.sequence += 64 __log__.info('Attempting to set the right higher sequence') self._resend_request(bad_msg.bad_msg_id) return True elif bad_msg.error_code == 33: # msg_seqno too high never seems to happen but just in case - self.session._sequence -= 16 + self.session.sequence -= 16 __log__.info('Attempting to set the right lower sequence') self._resend_request(bad_msg.bad_msg_id) return True diff --git a/telethon/sessions/__init__.py b/telethon/sessions/__init__.py new file mode 100644 index 00000000..af3423f3 --- /dev/null +++ b/telethon/sessions/__init__.py @@ -0,0 +1,3 @@ +from .abstract import Session +from .memory import MemorySession +from .sqlite import SQLiteSession diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py new file mode 100644 index 00000000..c7392ffc --- /dev/null +++ b/telethon/sessions/abstract.py @@ -0,0 +1,136 @@ +from abc import ABC, abstractmethod + + +class Session(ABC): + @abstractmethod + def clone(self): + raise NotImplementedError + + @abstractmethod + def set_dc(self, dc_id, server_address, port): + raise NotImplementedError + + @property + @abstractmethod + def server_address(self): + raise NotImplementedError + + @property + @abstractmethod + def port(self): + raise NotImplementedError + + @property + @abstractmethod + def auth_key(self): + raise NotImplementedError + + @auth_key.setter + @abstractmethod + def auth_key(self, value): + raise NotImplementedError + + @property + @abstractmethod + def time_offset(self): + raise NotImplementedError + + @time_offset.setter + 
@abstractmethod + def time_offset(self, value): + raise NotImplementedError + + @property + @abstractmethod + def salt(self): + raise NotImplementedError + + @salt.setter + @abstractmethod + def salt(self, value): + raise NotImplementedError + + @property + @abstractmethod + def device_model(self): + raise NotImplementedError + + @property + @abstractmethod + def system_version(self): + raise NotImplementedError + + @property + @abstractmethod + def app_version(self): + raise NotImplementedError + + @property + @abstractmethod + def lang_code(self): + raise NotImplementedError + + @property + @abstractmethod + def system_lang_code(self): + raise NotImplementedError + + @property + @abstractmethod + def report_errors(self): + raise NotImplementedError + + @property + @abstractmethod + def sequence(self): + raise NotImplementedError + + @property + @abstractmethod + def flood_sleep_threshold(self): + raise NotImplementedError + + @abstractmethod + def close(self): + raise NotImplementedError + + @abstractmethod + def save(self): + raise NotImplementedError + + @abstractmethod + def delete(self): + raise NotImplementedError + + @classmethod + @abstractmethod + def list_sessions(cls): + raise NotImplementedError + + @abstractmethod + def get_new_msg_id(self): + raise NotImplementedError + + @abstractmethod + def update_time_offset(self, correct_msg_id): + raise NotImplementedError + + @abstractmethod + def generate_sequence(self, content_related): + raise NotImplementedError + + @abstractmethod + def process_entities(self, tlo): + raise NotImplementedError + + @abstractmethod + def get_input_entity(self, key): + raise NotImplementedError + + @abstractmethod + def cache_file(self, md5_digest, file_size, instance): + raise NotImplementedError + + @abstractmethod + def get_file(self, md5_digest, file_size, cls): + raise NotImplementedError diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py new file mode 100644 index 00000000..66558829 --- /dev/null +++ 
b/telethon/sessions/memory.py
@@ -0,0 +1,308 @@
+from enum import Enum
+import time
+import platform
+
+from .. import utils
+from .abstract import Session
+from ..tl import TLObject
+
+from ..tl.types import (
+    PeerUser, PeerChat, PeerChannel,
+    InputPeerUser, InputPeerChat, InputPeerChannel,
+    InputPhoto, InputDocument
+)
+
+
+class _SentFileType(Enum):
+    DOCUMENT = 0
+    PHOTO = 1
+
+    @staticmethod
+    def from_type(cls):
+        if cls == InputDocument:
+            return _SentFileType.DOCUMENT
+        elif cls == InputPhoto:
+            return _SentFileType.PHOTO
+        else:
+            raise ValueError('The cls must be either InputDocument/InputPhoto')
+
+
+class MemorySession(Session):
+    def __init__(self):
+        self._dc_id = None
+        self._server_address = None
+        self._port = None
+        self._salt = None
+        self._auth_key = None
+        self._sequence = 0
+        self._last_msg_id = 0
+        self._time_offset = 0
+        self._flood_sleep_threshold = 60
+
+        system = platform.uname()
+        self._device_model = system.system or 'Unknown'
+        self._system_version = system.release or '1.0'
+        self._app_version = '1.0'
+        self._lang_code = 'en'
+        self._system_lang_code = self.lang_code
+        self._report_errors = True
+
+        self._files = {}
+        self._entities = set()
+
+    def clone(self):
+        """Returns a new session copying the connection parameters."""
+        cloned = MemorySession()
+        cloned._device_model = self.device_model
+        cloned._system_version = self.system_version
+        cloned._app_version = self.app_version
+        cloned._lang_code = self.lang_code
+        cloned._system_lang_code = self.system_lang_code
+        cloned._report_errors = self.report_errors
+        cloned._flood_sleep_threshold = self.flood_sleep_threshold
+        return cloned
+
+    def set_dc(self, dc_id, server_address, port):
+        self._dc_id = dc_id
+        self._server_address = server_address
+        self._port = port
+
+    @property
+    def server_address(self):
+        return self._server_address
+
+    @property
+    def port(self):
+        return self._port
+
+    @property
+    def auth_key(self):
+        return self._auth_key
+
+    @auth_key.setter
+    def auth_key(self, value):
+        self._auth_key = value
+
+    @property
+    def time_offset(self):
+        return self._time_offset
+
+    @time_offset.setter
+    def time_offset(self, value):
+        self._time_offset = value
+
+    @property
+    def salt(self):
+        return self._salt
+
+    @salt.setter
+    def salt(self, value):
+        self._salt = value
+
+    @property
+    def device_model(self):
+        return self._device_model
+
+    @property
+    def system_version(self):
+        return self._system_version
+
+    @property
+    def app_version(self):
+        return self._app_version
+
+    @property
+    def lang_code(self):
+        return self._lang_code
+
+    @property
+    def system_lang_code(self):
+        return self._system_lang_code
+
+    @property
+    def report_errors(self):
+        return self._report_errors
+
+    @property
+    def sequence(self):
+        return self._sequence
+
+    @sequence.setter
+    def sequence(self, value):
+        # MtProtoSender adjusts the sequence in place on bad_msg errors.
+        self._sequence = value
+
+    @property
+    def flood_sleep_threshold(self):
+        return self._flood_sleep_threshold
+
+    def close(self):
+        pass
+
+    def save(self):
+        pass
+
+    def delete(self):
+        pass
+
+    @classmethod
+    def list_sessions(cls):
+        raise NotImplementedError
+
+    def get_new_msg_id(self):
+        """Generates a new unique message ID based on the current
+           time (in ms) since epoch"""
+        now = time.time() + self._time_offset
+        nanoseconds = int((now - int(now)) * 1e+9)
+        new_msg_id = (int(now) << 32) | (nanoseconds << 2)
+
+        if self._last_msg_id >= new_msg_id:
+            new_msg_id = self._last_msg_id + 4
+
+        self._last_msg_id = new_msg_id
+
+        return new_msg_id
+
+    def update_time_offset(self, correct_msg_id):
+        now = int(time.time())
+        correct = correct_msg_id >> 32
+        self._time_offset = correct - now
+        self._last_msg_id = 0
+
+    def generate_sequence(self, content_related):
+        if content_related:
+            result = self._sequence * 2 + 1
+            self._sequence += 1
+            return result
+        else:
+            return self._sequence * 2
+
+    @staticmethod
+    def _entities_to_rows(tlo):
+        if not isinstance(tlo, TLObject) and utils.is_list_like(tlo):
+            # This may be a list of users already for instance
+            entities = tlo
+        else:
+            entities = []
+            if hasattr(tlo, 'chats') and utils.is_list_like(tlo.chats):
+                entities.extend(tlo.chats)
+            if hasattr(tlo, 'users') and utils.is_list_like(tlo.users):
+                entities.extend(tlo.users)
+            if not entities:
+                return
+
+        rows = []  # Rows to add (id, hash, username, phone, name)
+        for e in entities:
+            if not isinstance(e, TLObject):
+                continue
+            try:
+                p = utils.get_input_peer(e, allow_self=False)
+                marked_id = utils.get_peer_id(p)
+            except ValueError:
+                continue
+
+            if isinstance(p, (InputPeerUser, InputPeerChannel)):
+                if not p.access_hash:
+                    # Some users and channels seem to be returned without
+                    # an 'access_hash', meaning Telegram doesn't want you
+                    # to access them. This is the reason behind ensuring
+                    # that the 'access_hash' is non-zero. See issue #354.
+                    # Note that this checks for zero or None, see #392.
+                    continue
+                else:
+                    p_hash = p.access_hash
+            elif isinstance(p, InputPeerChat):
+                p_hash = 0
+            else:
+                continue
+
+            username = getattr(e, 'username', None) or None
+            if username is not None:
+                username = username.lower()
+            phone = getattr(e, 'phone', None)
+            name = utils.get_display_name(e) or None
+            rows.append((marked_id, p_hash, username, phone, name))
+        return rows
+
+    def process_entities(self, tlo):
+        """Stores the entity rows found in ``tlo`` for later lookups."""
+        # ``_entities_to_rows`` returns None when nothing was found, and
+        # sets are merged in place with ``|=`` (they do not support ``+=``).
+        self._entities |= set(self._entities_to_rows(tlo) or ())
+
+    def get_entity_rows_by_phone(self, phone):
+        rows = [(id, hash) for id, hash, _, found_phone, _
+                in self._entities if found_phone == phone]
+        return rows[0] if rows else None
+
+    def get_entity_rows_by_username(self, username):
+        rows = [(id, hash) for id, hash, found_username, _, _
+                in self._entities if found_username == username]
+        return rows[0] if rows else None
+
+    def get_entity_rows_by_name(self, name):
+        rows = [(id, hash) for id, hash, _, _, found_name
+                in self._entities if found_name == name]
+        return rows[0] if rows else None
+
+    def get_entity_rows_by_id(self, id):
+        rows = [(id, hash) for found_id, hash, _, _, _
+                in self._entities if found_id == id]
+        return rows[0] if rows else None
+
+    def get_input_entity(self, key):
+        try:
+            if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
+                # hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
+                # We already have an Input version, so nothing else required
+                return key
+            # Try to early return if this key can be casted as input peer
+            return utils.get_input_peer(key)
+        except (AttributeError, TypeError):
+            # Not a TLObject or can't be cast into InputPeer
+            if isinstance(key, TLObject):
+                key = utils.get_peer_id(key)
+
+        result = None
+        if isinstance(key, str):
+            phone = utils.parse_phone(key)
+            if phone:
+                result = self.get_entity_rows_by_phone(phone)
+            else:
+                username, _ = utils.parse_username(key)
+                if username:
+                    result = self.get_entity_rows_by_username(username)
+
+        if isinstance(key, int):
+            result = self.get_entity_rows_by_id(key)
+
+        if not result and isinstance(key, str):
+            result = self.get_entity_rows_by_name(key)
+
+        if result:
+            i, h = result  # unpack resulting tuple
+            i, k = utils.resolve_id(i)  # removes the mark and returns kind
+            if k == PeerUser:
+                return InputPeerUser(i, h)
+            elif k == PeerChat:
+                return InputPeerChat(i)
+            elif k == PeerChannel:
+                return InputPeerChannel(i, h)
+        else:
+            raise ValueError('Could not find input entity with key ', key)
+
+    def cache_file(self, md5_digest, file_size, instance):
+        """Remembers the ID and access hash of an uploaded file."""
+        if not isinstance(instance, (InputDocument, InputPhoto)):
+            raise TypeError('Cannot cache %s instance' % type(instance))
+        # from_type() compares against the classes, so pass the type.
+        key = (md5_digest, file_size, _SentFileType.from_type(type(instance)))
+        value = (instance.id, instance.access_hash)
+        self._files[key] = value
+
+    def get_file(self, md5_digest, file_size, cls):
+        key = (md5_digest, file_size, _SentFileType.from_type(cls))
+        try:
+            return self._files[key]
+        except KeyError:
+            return None
diff --git a/telethon/session.py b/telethon/sessions/sqlite.py
similarity index 59%
rename from telethon/session.py
rename to telethon/sessions/sqlite.py
index 6b374c39..66a0c887 100644
--- a/telethon/session.py
+++ b/telethon/sessions/sqlite.py
@@ -5,14 +5,15 @@ import sqlite3
 import struct
 import 
time from base64 import b64decode -from enum import Enum from os.path import isfile as file_exists from threading import Lock, RLock -from . import utils -from .crypto import AuthKey -from .tl import TLObject -from .tl.types import ( +from .. import utils +from .abstract import Session +from .memory import MemorySession, _SentFileType +from ..crypto import AuthKey +from ..tl import TLObject +from ..tl.types import ( PeerUser, PeerChat, PeerChannel, InputPeerUser, InputPeerChat, InputPeerChannel, InputPhoto, InputDocument @@ -22,21 +23,7 @@ EXTENSION = '.session' CURRENT_VERSION = 3 # database version -class _SentFileType(Enum): - DOCUMENT = 0 - PHOTO = 1 - - @staticmethod - def from_type(cls): - if cls == InputDocument: - return _SentFileType.DOCUMENT - elif cls == InputPhoto: - return _SentFileType.PHOTO - else: - raise ValueError('The cls must be either InputDocument/InputPhoto') - - -class Session: +class SQLiteSession(MemorySession): """This session contains the required information to login into your Telegram account. NEVER give the saved JSON file to anyone, since they would gain instant access to all your messages and contacts. @@ -44,7 +31,9 @@ class Session: If you think the session has been compromised, close all the sessions through an official Telegram client to revoke the authorization. """ + def __init__(self, session_id): + super().__init__() """session_user_id should either be a string or another Session. Note that if another session is given, only parameters like those required to init a connection will be copied. 
@@ -54,15 +43,15 @@ class Session: # For connection purposes if isinstance(session_id, Session): - self.device_model = session_id.device_model - self.system_version = session_id.system_version - self.app_version = session_id.app_version - self.lang_code = session_id.lang_code - self.system_lang_code = session_id.system_lang_code - self.lang_pack = session_id.lang_pack - self.report_errors = session_id.report_errors - self.save_entities = session_id.save_entities - self.flood_sleep_threshold = session_id.flood_sleep_threshold + self._device_model = session_id.device_model + self._system_version = session_id.system_version + self._app_version = session_id.app_version + self._lang_code = session_id.lang_code + self._system_lang_code = session_id.system_lang_code + self._report_errors = session_id.report_errors + self._flood_sleep_threshold = session_id.flood_sleep_threshold + if isinstance(session_id, SQLiteSession): + self.save_entities = session_id.save_entities else: # str / None if session_id: self.filename = session_id @@ -70,15 +59,14 @@ class Session: self.filename += EXTENSION system = platform.uname() - self.device_model = system.system or 'Unknown' - self.system_version = system.release or '1.0' - self.app_version = '1.0' # '0' will provoke error - self.lang_code = 'en' - self.system_lang_code = self.lang_code - self.lang_pack = '' - self.report_errors = True + self._device_model = system.system or 'Unknown' + self._system_version = system.release or '1.0' + self._app_version = '1.0' # '0' will provoke error + self._lang_code = 'en' + self._system_lang_code = self.lang_code + self._report_errors = True self.save_entities = True - self.flood_sleep_threshold = 60 + self._flood_sleep_threshold = 60 self.id = struct.unpack('q', os.urandom(8))[0] self._sequence = 0 @@ -163,6 +151,9 @@ class Session: c.close() self.save() + def clone(self): + return SQLiteSession(self) + def _check_migrate_json(self): if file_exists(self.filename): try: @@ -233,19 +224,7 @@ class 
Session: self._auth_key = None c.close() - @property - def server_address(self): - return self._server_address - - @property - def port(self): - return self._port - - @property - def auth_key(self): - return self._auth_key - - @auth_key.setter + @MemorySession.auth_key.setter def auth_key(self, value): self._auth_key = value self._update_session_table() @@ -298,53 +277,14 @@ class Session: except OSError: return False - @staticmethod - def list_sessions(): + @classmethod + def list_sessions(cls): """Lists all the sessions of the users who have ever connected using this client and never logged out """ return [os.path.splitext(os.path.basename(f))[0] for f in os.listdir('.') if f.endswith(EXTENSION)] - def generate_sequence(self, content_related): - """Thread safe method to generates the next sequence number, - based on whether it was confirmed yet or not. - - Note that if confirmed=True, the sequence number - will be increased by one too - """ - with self._seq_no_lock: - if content_related: - result = self._sequence * 2 + 1 - self._sequence += 1 - return result - else: - return self._sequence * 2 - - def get_new_msg_id(self): - """Generates a new unique message ID based on the current - time (in ms) since epoch""" - # Refer to mtproto_plain_sender.py for the original method - now = time.time() + self.time_offset - nanoseconds = int((now - int(now)) * 1e+9) - # "message identifiers are divisible by 4" - new_msg_id = (int(now) << 32) | (nanoseconds << 2) - - with self._msg_id_lock: - if self._last_msg_id >= new_msg_id: - new_msg_id = self._last_msg_id + 4 - - self._last_msg_id = new_msg_id - - return new_msg_id - - def update_time_offset(self, correct_msg_id): - """Updates the time offset based on a known correct message ID""" - now = int(time.time()) - correct = correct_msg_id >> 32 - self.time_offset = correct - now - self._last_msg_id = 0 - # Entity processing def process_entities(self, tlo): @@ -356,49 +296,7 @@ class Session: if not self.save_entities: return - if not 
isinstance(tlo, TLObject) and utils.is_list_like(tlo): - # This may be a list of users already for instance - entities = tlo - else: - entities = [] - if hasattr(tlo, 'chats') and utils.is_list_like(tlo.chats): - entities.extend(tlo.chats) - if hasattr(tlo, 'users') and utils.is_list_like(tlo.users): - entities.extend(tlo.users) - if not entities: - return - - rows = [] # Rows to add (id, hash, username, phone, name) - for e in entities: - if not isinstance(e, TLObject): - continue - try: - p = utils.get_input_peer(e, allow_self=False) - marked_id = utils.get_peer_id(p) - except ValueError: - continue - - if isinstance(p, (InputPeerUser, InputPeerChannel)): - if not p.access_hash: - # Some users and channels seem to be returned without - # an 'access_hash', meaning Telegram doesn't want you - # to access them. This is the reason behind ensuring - # that the 'access_hash' is non-zero. See issue #354. - # Note that this checks for zero or None, see #392. - continue - else: - p_hash = p.access_hash - elif isinstance(p, InputPeerChat): - p_hash = 0 - else: - continue - - username = getattr(e, 'username', None) or None - if username is not None: - username = username.lower() - phone = getattr(e, 'phone', None) - name = utils.get_display_name(e) or None - rows.append((marked_id, p_hash, username, phone, name)) + rows = self._entities_to_rows(tlo) if not rows: return @@ -408,62 +306,26 @@ class Session: ) self.save() - def get_input_entity(self, key): - """Parses the given string, integer or TLObject key into a - marked entity ID, which is then used to fetch the hash - from the database. - - If a callable key is given, every row will be fetched, - and passed as a tuple to a function, that should return - a true-like value when the desired row is found. - - Raises ValueError if it cannot be found. 
- """ - try: - if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd): - # hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel')) - # We already have an Input version, so nothing else required - return key - # Try to early return if this key can be casted as input peer - return utils.get_input_peer(key) - except (AttributeError, TypeError): - # Not a TLObject or can't be cast into InputPeer - if isinstance(key, TLObject): - key = utils.get_peer_id(key) - + def _fetchone_entity(self, query, args): c = self._cursor() - if isinstance(key, str): - phone = utils.parse_phone(key) - if phone: - c.execute('select id, hash from entities where phone=?', - (phone,)) - else: - username, _ = utils.parse_username(key) - if username: - c.execute('select id, hash from entities where username=?', + c.execute(query, args) + return c.fetchone() + + def get_entity_rows_by_phone(self, phone): + return self._fetchone_entity( + 'select id, hash from entities where phone=?', (phone,)) + + def get_entity_rows_by_username(self, username): + return self._fetchone_entity('select id, hash from entities where username=?', (username,)) - if isinstance(key, int): - c.execute('select id, hash from entities where id=?', (key,)) + def get_entity_rows_by_name(self, name): + return self._fetchone_entity('select id, hash from entities where name=?', + (name,)) - result = c.fetchone() - if not result and isinstance(key, str): - # Try exact match by name if phone/username failed - c.execute('select id, hash from entities where name=?', (key,)) - result = c.fetchone() - - c.close() - if result: - i, h = result # unpack resulting tuple - i, k = utils.resolve_id(i) # removes the mark and returns kind - if k == PeerUser: - return InputPeerUser(i, h) - elif k == PeerChat: - return InputPeerChat(i) - elif k == PeerChannel: - return InputPeerChannel(i, h) - else: - raise ValueError('Could not find input entity with key ', key) + def get_entity_rows_by_id(self, id): + return self._fetchone_entity('select id, hash from 
entities where id=?', + (id,)) # File processing diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index 8a15476e..3a5b2bd0 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -14,7 +14,7 @@ from .errors import ( PhoneMigrateError, NetworkMigrateError, UserMigrateError ) from .network import authenticator, MtProtoSender, Connection, ConnectionMode -from .session import Session +from .sessions import Session, SQLiteSession from .tl import TLObject from .tl.all_tlobjects import LAYER from .tl.functions import ( @@ -81,10 +81,10 @@ class TelegramBareClient: "Refer to telethon.rtfd.io for more information.") self._use_ipv6 = use_ipv6 - + # Determine what session object we have if isinstance(session, str) or session is None: - session = Session(session) + session = SQLiteSession(session) elif not isinstance(session, Session): raise TypeError( 'The given session must be a str or a Session instance.' @@ -361,7 +361,7 @@ class TelegramBareClient: # # Construct this session with the connection parameters # (system version, device model...) from the current one. - session = Session(self.session) + session = self.session.clone() session.set_dc(dc.id, dc.ip_address, dc.port) self._exported_sessions[dc_id] = session @@ -387,7 +387,7 @@ class TelegramBareClient: session = self._exported_sessions.get(cdn_redirect.dc_id) if not session: dc = self._get_dc(cdn_redirect.dc_id, cdn=True) - session = Session(self.session) + session = self.session.clone() session.set_dc(dc.id, dc.ip_address, dc.port) self._exported_sessions[cdn_redirect.dc_id] = session