Fix remaining upgraded uses of the session to work correctly

This commit is contained in:
Lonami Exo
2021-09-19 17:08:51 +02:00
parent d33402f02e
commit 9479e215fb
5 changed files with 28 additions and 23 deletions

View File

@@ -59,7 +59,7 @@ class Session(ABC):
raise NotImplementedError
@abstractmethod
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
"""
Get the `Entity` with matching ``ty`` and ``id``.
@@ -75,6 +75,8 @@ class Session(ABC):
the corresponding ``access_hash`` should still be returned.
You may use `types.canonical_entity_type` to find out the canonical type.
A ``ty`` with the value of ``None`` should be treated as "any entity with matching ID".
"""
raise NotImplementedError

View File

@@ -36,7 +36,7 @@ class MemorySession(Session):
async def insert_entities(self, entities: List[Entity]):
self.entities.update((e.id, (e.ty, e.access_hash)) for e in entities)
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
try:
ty, access_hash = self.entities[id]
return Entity(ty, id, access_hash)

View File

@@ -55,13 +55,17 @@ class SQLiteSession(Session):
self._upgrade_database(old=version)
c.execute("delete from version")
c.execute("insert into version values (?)", (CURRENT_VERSION,))
self.save()
self._conn.commit()
else:
# Tables don't exist, create new ones
self._create_table(c, 'version (version integer primary key)')
self._mk_tables(c)
c.execute("insert into version values (?)", (CURRENT_VERSION,))
c.close()
self.save()
self._conn.commit()
# Must have committed or else the version will not have been updated while new tables
# exist, leading to a half-upgraded state.
c.close()
def _upgrade_database(self, old):
c = self._cursor()
@@ -146,9 +150,6 @@ class SQLiteSession(Session):
def _mk_tables(self, c):
self._create_table(
c,
'''version (
version integer primary key
)''',
'''datacenter (
id integer primary key,
ip text not null,
@@ -243,7 +244,7 @@ class SQLiteSession(Session):
finally:
c.close()
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
row = self._execute('select ty, id, access_hash from entity where id = ?', id)
return Entity(*row) if row else None