diff --git a/telethon/entity_database.py b/telethon/entity_database.py deleted file mode 100644 index 1e96fe51..00000000 --- a/telethon/entity_database.py +++ /dev/null @@ -1,75 +0,0 @@ -from . import utils -from .tl import TLObject - - -class EntityDatabase: - def __init__(self, enabled=True): - self.enabled = enabled - - self._entities = {} # marked_id: user|chat|channel - - # TODO Allow disabling some extra mappings - self._username_id = {} # username: marked_id - - def add(self, entity): - if not self.enabled: - return - - # Adds or updates the given entity - marked_id = utils.get_peer_id(entity, add_mark=True) - try: - old_entity = self._entities[marked_id] - old_entity.__dict__.update(entity) # Keep old references - - # Update must delete old username - username = getattr(old_entity, 'username', None) - if username: - del self._username_id[username.lower()] - except KeyError: - # Add new entity - self._entities[marked_id] = entity - - # Always update username if any - username = getattr(entity, 'username', None) - if username: - self._username_id[username.lower()] = marked_id - - def __getitem__(self, key): - """Accepts a digit only string as phone number, - otherwise it's treated as an username. - - If an integer is given, it's treated as the ID of the desired User. - The ID given won't try to be guessed as the ID of a chat or channel, - as there may be an user with that ID, and it would be unreliable. - - If a Peer is given (PeerUser, PeerChat, PeerChannel), - its specific entity is retrieved as User, Chat or Channel. - Note that megagroups are channels with .megagroup = True. - """ - if isinstance(key, str): - # TODO Parse phone properly, currently only usernames - key = key.lstrip('@').lower() - # TODO Use the client to return from username if not found - return self._entities[self._username_id[key]] - - if isinstance(key, int): - return self._entities[key] # normal IDs are assumed users - - if isinstance(key, TLObject) and type(key).SUBCLASS_OF_ID == 0x2d45687: - return self._entities[utils.get_peer_id(key, add_mark=True)] - - raise KeyError(key) - - def __delitem__(self, key): - target = self[key] - del self._entities[key] - if getattr(target, 'username'): - del self._username_id[target.username] - - # TODO Allow search by name by tokenizing the input and return a list - - def clear(self, target=None): - if target is None: - self._entities.clear() - else: - del self[target] diff --git a/telethon/tl/entity_database.py b/telethon/tl/entity_database.py new file mode 100644 index 00000000..6a1d7dbb --- /dev/null +++ b/telethon/tl/entity_database.py @@ -0,0 +1,140 @@ +from threading import Lock + +from .. import utils +from ..tl import TLObject +from ..tl.types import User, Chat, Channel + + +class EntityDatabase: + def __init__(self, input_list=None, enabled=True): + self.enabled = enabled + + self._lock = Lock() + self._entities = {} # marked_id: user|chat|channel + + if input_list: + self._input_entities = {k: v for k, v in input_list} + else: + self._input_entities = {} # marked_id: hash + + # TODO Allow disabling some extra mappings + self._username_id = {} # username: marked_id + + def process(self, tlobject): + """Processes all the found entities on the given TLObject, + unless .enabled is False. + + Returns True if new input entities were added. + """ + if not self.enabled: + return False + + # Save all input entities we know of + entities = [] + if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'): + entities.extend(tlobject.chats) + if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'): + entities.extend(tlobject.users) + + return self.expand(entities) + + def expand(self, entities): + """Adds new input entities to the local database unconditionally. + Unknown types will be ignored. + """ + if not entities or not self.enabled: + return False + + new = [] # Array of entities (User, Chat, or Channel) + new_input = {} # Dictionary of {entity_marked_id: access_hash} + for e in entities: + if not isinstance(e, TLObject): + continue + + try: + p = utils.get_input_peer(e) + new_input[utils.get_peer_id(p, add_mark=True)] = \ + getattr(p, 'access_hash', 0) # chats won't have hash + + if isinstance(e, User) \ + or isinstance(e, Chat) \ + or isinstance(e, Channel): + new.append(e) + except ValueError: + pass + + with self._lock: + before = len(self._input_entities) + self._input_entities.update(new_input) + for e in new: + self._add_full_entity(e) + return len(self._input_entities) != before + + def _add_full_entity(self, entity): + """Adds a "full" entity (User, Chat or Channel, not "Input*"). + + Not to be confused with UserFull, ChatFull, or ChannelFull, + "full" means simply not "Input*". + """ + marked_id = utils.get_peer_id( + utils.get_input_peer(entity), add_mark=True + ) + try: + old_entity = self._entities[marked_id] + old_entity.__dict__.update(entity.__dict__) # Keep old references + + # Update must delete old username + username = getattr(old_entity, 'username', None) + if username: + del self._username_id[username.lower()] + except KeyError: + # Add new entity + self._entities[marked_id] = entity + + # Always update username if any + username = getattr(entity, 'username', None) + if username: + self._username_id[username.lower()] = marked_id + + def __getitem__(self, key): + """Accepts a digit only string as phone number, + otherwise it's treated as an username. + + If an integer is given, it's treated as the ID of the desired User. + The ID given won't try to be guessed as the ID of a chat or channel, + as there may be an user with that ID, and it would be unreliable. + + If a Peer is given (PeerUser, PeerChat, PeerChannel), + its specific entity is retrieved as User, Chat or Channel. + Note that megagroups are channels with .megagroup = True. + """ + if isinstance(key, str): + # TODO Parse phone properly, currently only usernames + key = key.lstrip('@').lower() + # TODO Use the client to return from username if not found + return self._entities[self._username_id[key]] + + if isinstance(key, int): + return self._entities[key] # normal IDs are assumed users + + if isinstance(key, TLObject) and type(key).SUBCLASS_OF_ID == 0x2d45687: + return self._entities[utils.get_peer_id(key, add_mark=True)] + + raise KeyError(key) + + def __delitem__(self, key): + target = self[key] + del self._entities[key] + if getattr(target, 'username'): + del self._username_id[target.username] + + # TODO Allow search by name by tokenizing the input and return a list + + def get_input_list(self): + return list(self._input_entities.items()) + + def clear(self, target=None): + if target is None: + self._entities.clear() + else: + del self[target] diff --git a/telethon/tl/session.py b/telethon/tl/session.py index d3854c8d..2b691ad7 100644 --- a/telethon/tl/session.py +++ b/telethon/tl/session.py @@ -6,11 +6,8 @@ from base64 import b64encode, b64decode from os.path import isfile as file_exists from threading import Lock -from .. import helpers, utils -from ..tl.types import ( - InputPeerUser, InputPeerChat, InputPeerChannel, - PeerUser, PeerChat, PeerChannel -) +from .entity_database import EntityDatabase +from .. import helpers class Session: @@ -70,8 +67,7 @@ class Session: self.auth_key = None self.layer = 0 self.salt = 0 # Unsigned long - self._input_entities = {} # {marked_id: hash} - self._entities_lock = Lock() + self.entities = EntityDatabase() # Known and cached entities def save(self): """Saves the current session object as session_user_id.session""" @@ -90,7 +86,7 @@ class Session: if self.auth_key else None } if self.save_entities: - out_dict['entities'] = list(self._input_entities.items()) + out_dict['entities'] = self.entities.get_input_list() json.dump(out_dict, file) @@ -139,8 +135,7 @@ class Session: key = b64decode(data['auth_key_data']) result.auth_key = AuthKey(data=key) - for e_mid, e_hash in data.get('entities', []): - result._input_entities[e_mid] = e_hash + result.entities = EntityDatabase(data.get('entities', [])) except (json.decoder.JSONDecodeError, UnicodeDecodeError): pass @@ -186,58 +181,5 @@ class Session: self.time_offset = correct - now def process_entities(self, tlobject): - """Processes all the found entities on the given TLObject, - unless .save_entities is False, and saves the session file. - """ - if not self.save_entities: - return - - # Save all input entities we know of - entities = [] - if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'): - entities.extend(tlobject.chats) - if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'): - entities.extend(tlobject.users) - - if self.add_entities(entities): + if self.entities.process(tlobject): self.save() # Save if any new entities got added - - def add_entities(self, entities): - """Adds new input entities to the local database unconditionally. - Unknown types will be ignored. - """ - if not entities: - return False - - new = {} - for e in entities: - try: - p = utils.get_input_peer(e) - new[utils.get_peer_id(p, add_mark=True)] = \ - getattr(p, 'access_hash', 0) # chats won't have hash - except ValueError: - pass - - with self._entities_lock: - before = len(self._input_entities) - self._input_entities.update(new) - return len(self._input_entities) != before - - def get_input_entity(self, peer): - """Gets an input entity known its Peer or a marked ID, - or raises KeyError if not found/invalid. - """ - if not isinstance(peer, int): - peer = utils.get_peer_id(peer, add_mark=True) - - entity_hash = self._input_entities[peer] - entity_id, peer_class = utils.resolve_id(peer) - - if peer_class == PeerUser: - return InputPeerUser(entity_id, entity_hash) - if peer_class == PeerChat: - return InputPeerChat(entity_id) - if peer_class == PeerChannel: - return InputPeerChannel(entity_id, entity_hash) - - raise KeyError()