Remove custom EntityDatabase and use sqlite3 instead

There are still a few things to change, like cleaning up the
code and actually caching the entities as a whole (currently,
although the username/phone/name can be used to fetch their
input version which is an improvement, their full version
needs to be re-fetched. Maybe it's a good thing though?)
This commit is contained in:
Lonami Exo
2017-12-27 00:50:09 +01:00
parent 0a4849b150
commit aef96f1b68
4 changed files with 181 additions and 303 deletions

View File

@@ -1,252 +0,0 @@
import re
from threading import Lock
from ..tl import TLObject
from ..tl.types import (
User, Chat, Channel, PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel
)
from .. import utils # Keep this line the last to maybe fix #357
USERNAME_RE = re.compile(
r'@|(?:https?://)?(?:telegram\.(?:me|dog)|t\.me)/(joinchat/)?'
)
class EntityDatabase:
def __init__(self, input_list=None, enabled=True, enabled_full=True):
"""Creates a new entity database with an initial load of "Input"
entities, if any.
If 'enabled', input entities will be saved. The whole entity
will be saved if both 'enabled' and 'enabled_full' are True.
"""
self.enabled = enabled
self.enabled_full = enabled_full
self._lock = Lock()
self._entities = {} # marked_id: user|chat|channel
if input_list:
# TODO For compatibility reasons some sessions were saved with
# 'access_hash': null in the JSON session file. Drop these, as
# it means we don't have access to such InputPeers. Issue #354.
self._input_entities = {
k: v for k, v in input_list if v is not None
}
else:
self._input_entities = {} # marked_id: hash
# TODO Allow disabling some extra mappings
self._username_id = {} # username: marked_id
self._phone_id = {} # phone: marked_id
def process(self, tlobject):
"""Processes all the found entities on the given TLObject,
unless .enabled is False.
Returns True if new input entities were added.
"""
if not self.enabled:
return False
# Save all input entities we know of
if not isinstance(tlobject, TLObject) and hasattr(tlobject, '__iter__'):
# This may be a list of users already for instance
return self.expand(tlobject)
entities = []
if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
entities.extend(tlobject.chats)
if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
entities.extend(tlobject.users)
return self.expand(entities)
def expand(self, entities):
"""Adds new input entities to the local database unconditionally.
Unknown types will be ignored.
"""
if not entities or not self.enabled:
return False
new = [] # Array of entities (User, Chat, or Channel)
new_input = {} # Dictionary of {entity_marked_id: access_hash}
for e in entities:
if not isinstance(e, TLObject):
continue
try:
p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p, add_mark=True)
has_hash = False
if isinstance(p, InputPeerChat):
# Chats don't have a hash
new_input[marked_id] = 0
has_hash = True
elif p.access_hash:
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
new_input[marked_id] = p.access_hash
has_hash = True
if self.enabled_full and has_hash:
if isinstance(e, (User, Chat, Channel)):
new.append(e)
except ValueError:
pass
with self._lock:
before = len(self._input_entities)
self._input_entities.update(new_input)
for e in new:
self._add_full_entity(e)
return len(self._input_entities) != before
def _add_full_entity(self, entity):
"""Adds a "full" entity (User, Chat or Channel, not "Input*"),
despite the value of self.enabled and self.enabled_full.
Not to be confused with UserFull, ChatFull, or ChannelFull,
"full" means simply not "Input*".
"""
marked_id = utils.get_peer_id(
utils.get_input_peer(entity, allow_self=False), add_mark=True
)
try:
old_entity = self._entities[marked_id]
old_entity.__dict__.update(entity.__dict__) # Keep old references
# Update must delete old username and phone
username = getattr(old_entity, 'username', None)
if username:
del self._username_id[username.lower()]
phone = getattr(old_entity, 'phone', None)
if phone:
del self._phone_id[phone]
except KeyError:
# Add new entity
self._entities[marked_id] = entity
# Always update username or phone if any
username = getattr(entity, 'username', None)
if username:
self._username_id[username.lower()] = marked_id
phone = getattr(entity, 'phone', None)
if phone:
self._phone_id[phone] = marked_id
def _parse_key(self, key):
"""Parses the given string, integer or TLObject key into a
marked user ID ready for use on self._entities.
If a callable key is given, the entity will be passed to the
function, and if it returns a true-like value, the marked ID
for such entity will be returned.
Raises ValueError if it cannot be parsed.
"""
if isinstance(key, str):
phone = EntityDatabase.parse_phone(key)
try:
if phone:
return self._phone_id[phone]
else:
username, _ = EntityDatabase.parse_username(key)
return self._username_id[username.lower()]
except KeyError as e:
raise ValueError() from e
if isinstance(key, int):
return key # normal IDs are assumed users
if isinstance(key, TLObject):
return utils.get_peer_id(key, add_mark=True)
if callable(key):
for k, v in self._entities.items():
if key(v):
return k
raise ValueError()
def __getitem__(self, key):
"""See the ._parse_key() docstring for possible values of the key"""
try:
return self._entities[self._parse_key(key)]
except (ValueError, KeyError) as e:
raise KeyError(key) from e
def __delitem__(self, key):
try:
old = self._entities.pop(self._parse_key(key))
# Try removing the username and phone (if pop didn't fail),
# since the entity may have no username or phone, just ignore
# errors. It should be there if we popped the entity correctly.
try:
del self._username_id[getattr(old, 'username', None)]
except KeyError:
pass
try:
del self._phone_id[getattr(old, 'phone', None)]
except KeyError:
pass
except (ValueError, KeyError) as e:
raise KeyError(key) from e
@staticmethod
def parse_phone(phone):
"""Parses the given phone, or returns None if it's invalid"""
if isinstance(phone, int):
return str(phone)
else:
phone = re.sub(r'[+()\s-]', '', str(phone))
if phone.isdigit():
return phone
@staticmethod
def parse_username(username):
"""Parses the given username or channel access hash, given
a string, username or URL. Returns a tuple consisting of
both the stripped username and whether it is a joinchat/ hash.
"""
username = username.strip()
m = USERNAME_RE.match(username)
if m:
return username[m.end():], bool(m.group(1))
else:
return username, False
def get_input_entity(self, peer):
try:
i = utils.get_peer_id(peer, add_mark=True)
h = self._input_entities[i] # we store the IDs marked
i, k = utils.resolve_id(i) # removes the mark and returns kind
if k == PeerUser:
return InputPeerUser(i, h)
elif k == PeerChat:
return InputPeerChat(i)
elif k == PeerChannel:
return InputPeerChannel(i, h)
except ValueError as e:
raise KeyError(peer) from e
raise KeyError(peer)
def get_input_list(self):
return list(self._input_entities.items())
def clear(self, target=None):
if target is None:
self._entities.clear()
else:
del self[target]

View File

@@ -8,8 +8,12 @@ from base64 import b64decode
from os.path import isfile as file_exists
from threading import Lock
from .entity_database import EntityDatabase
from .. import helpers
from .. import utils, helpers
from ..tl import TLObject
from ..tl.types import (
PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel
)
EXTENSION = '.session'
CURRENT_VERSION = 1 # database version
@@ -75,10 +79,9 @@ class Session:
self._auth_key = None
self._layer = 0
self._salt = 0 # Signed long
self.entities = EntityDatabase() # Known and cached entities
# Migrating from .json -> SQL
self._check_migrate_json()
entities = self._check_migrate_json()
self._conn = sqlite3.connect(self.filename, check_same_thread=False)
c = self._conn.cursor()
@@ -114,14 +117,20 @@ class Session:
)
c.execute(
"""create table entities (
id integer,
hash integer,
id integer primary key,
hash integer not null,
username text,
phone integer,
name text
)"""
)
c.execute("insert into version values (1)")
# Migrating from JSON -> new table and may have entities
if entities:
c.executemany(
'insert or replace into entities values (?,?,?,?,?)',
entities
)
c.close()
self.save()
@@ -130,6 +139,8 @@ class Session:
try:
with open(self.filename, encoding='utf-8') as f:
data = json.load(f)
self.delete() # Delete JSON file to create database
self._port = data.get('port', self._port)
self._salt = data.get('salt', self._salt)
# Keep while migrating from unsigned to signed salt
@@ -146,10 +157,12 @@ class Session:
key = b64decode(data['auth_key_data'])
self._auth_key = AuthKey(data=key)
self.entities = EntityDatabase(data.get('entities', []))
self.delete() # Delete JSON file to create database
rows = []
for p_id, p_hash in data.get('entities', []):
rows.append((p_id, p_hash, None, None, None))
return rows
except (UnicodeDecodeError, json.decoder.JSONDecodeError):
pass
return [] # No entities
def _upgrade_database(self, old):
pass
@@ -275,9 +288,103 @@ class Session:
correct = correct_msg_id >> 32
self.time_offset = correct - now
def process_entities(self, tlobject):
try:
if self.entities.process(tlobject):
self.save() # Save if any new entities got added
except:
pass
# Entity processing
def process_entities(self, tlo):
"""Processes all the found entities on the given TLObject,
unless .enabled is False.
Returns True if new input entities were added.
"""
if not self.save_entities:
return
if not isinstance(tlo, TLObject) and hasattr(tlo, '__iter__'):
# This may be a list of users already for instance
entities = tlo
else:
entities = []
if hasattr(tlo, 'chats') and hasattr(tlo.chats, '__iter__'):
entities.extend(tlo.chats)
if hasattr(tlo, 'users') and hasattr(tlo.users, '__iter__'):
entities.extend(tlo.users)
if not entities:
return
rows = [] # Rows to add (id, hash, username, phone, name)
for e in entities:
if not isinstance(e, TLObject):
continue
try:
p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p, add_mark=True)
p_hash = None
if isinstance(p, InputPeerChat):
p_hash = 0
elif p.access_hash:
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
p_hash = p.access_hash
if p_hash is not None:
username = getattr(e, 'username', None)
phone = getattr(e, 'phone', None)
name = utils.get_display_name(e) or None
rows.append((marked_id, p_hash, username, phone, name))
except ValueError:
pass
if not rows:
return
with self._db_lock:
self._conn.executemany(
'insert or replace into entities values (?,?,?,?,?)', rows
)
self.save()
def get_input_entity(self, key):
"""Parses the given string, integer or TLObject key into a
marked entity ID, which is then used to fetch the hash
from the database.
If a callable key is given, every row will be fetched,
and passed as a tuple to a function, that should return
a true-like value when the desired row is found.
Raises ValueError if it cannot be found.
"""
c = self._conn.cursor()
if isinstance(key, str):
phone = utils.parse_phone(key)
if phone:
c.execute('select id, hash from entities where phone=?',
(phone,))
else:
username, _ = utils.parse_username(key)
c.execute('select id, hash from entities where username=?',
(username,))
if isinstance(key, TLObject):
# crc32(b'InputPeer') and crc32(b'Peer')
if type(key).SUBCLASS_OF_ID == 0xc91c90b6:
return key
key = utils.get_peer_id(key, add_mark=True)
if isinstance(key, int):
c.execute('select id, hash from entities where id=?', (key,))
result = c.fetchone()
if result:
i, h = result # unpack resulting tuple
i, k = utils.resolve_id(i) # removes the mark and returns kind
if k == PeerUser:
return InputPeerUser(i, h)
elif k == PeerChat:
return InputPeerChat(i)
elif k == PeerChannel:
return InputPeerChannel(i, h)
else:
raise ValueError('Could not find input entity with key ', key)