Use EntityDatabase in the Session class

This commit is contained in:
Lonami Exo
2017-10-04 21:02:45 +02:00
parent 5be9df0eec
commit a0fc5ed54e
3 changed files with 146 additions and 139 deletions

View File

@@ -6,11 +6,8 @@ from base64 import b64encode, b64decode
from os.path import isfile as file_exists
from threading import Lock
from .. import helpers, utils
from ..tl.types import (
InputPeerUser, InputPeerChat, InputPeerChannel,
PeerUser, PeerChat, PeerChannel
)
from .entity_database import EntityDatabase
from .. import helpers
class Session:
@@ -70,8 +67,7 @@ class Session:
self.auth_key = None
self.layer = 0
self.salt = 0 # Unsigned long
self._input_entities = {} # {marked_id: hash}
self._entities_lock = Lock()
self.entities = EntityDatabase() # Known and cached entities
def save(self):
"""Saves the current session object as session_user_id.session"""
@@ -90,7 +86,7 @@ class Session:
if self.auth_key else None
}
if self.save_entities:
out_dict['entities'] = list(self._input_entities.items())
out_dict['entities'] = self.entities.get_input_list()
json.dump(out_dict, file)
@@ -139,8 +135,7 @@ class Session:
key = b64decode(data['auth_key_data'])
result.auth_key = AuthKey(data=key)
for e_mid, e_hash in data.get('entities', []):
result._input_entities[e_mid] = e_hash
result.entities = EntityDatabase(data.get('entities', []))
except (json.decoder.JSONDecodeError, UnicodeDecodeError):
pass
@@ -186,58 +181,5 @@ class Session:
self.time_offset = correct - now
def process_entities(self, tlobject):
"""Processes all the found entities on the given TLObject,
unless .save_entities is False, and saves the session file.
"""
if not self.save_entities:
return
# Save all input entities we know of
entities = []
if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
entities.extend(tlobject.chats)
if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
entities.extend(tlobject.users)
if self.add_entities(entities):
if self.entities.process(tlobject):
self.save() # Save if any new entities got added
def add_entities(self, entities):
"""Adds new input entities to the local database unconditionally.
Unknown types will be ignored.
"""
if not entities:
return False
new = {}
for e in entities:
try:
p = utils.get_input_peer(e)
new[utils.get_peer_id(p, add_mark=True)] = \
getattr(p, 'access_hash', 0) # chats won't have hash
except ValueError:
pass
with self._entities_lock:
before = len(self._input_entities)
self._input_entities.update(new)
return len(self._input_entities) != before
def get_input_entity(self, peer):
"""Gets an input entity known its Peer or a marked ID,
or raises KeyError if not found/invalid.
"""
if not isinstance(peer, int):
peer = utils.get_peer_id(peer, add_mark=True)
entity_hash = self._input_entities[peer]
entity_id, peer_class = utils.resolve_id(peer)
if peer_class == PeerUser:
return InputPeerUser(entity_id, entity_hash)
if peer_class == PeerChat:
return InputPeerChat(entity_id)
if peer_class == PeerChannel:
return InputPeerChannel(entity_id, entity_hash)
raise KeyError()