Improve TakeoutClient proxy and takeout functionality (#1106)

This commit is contained in:
Dmitry D. Chernov
2019-02-10 20:10:41 +10:00
committed by Lonami
parent 274fa72a8c
commit 9a98d41a2c
6 changed files with 153 additions and 51 deletions

View File

@@ -53,6 +53,23 @@ class Session(ABC):
"""
raise NotImplementedError
@property
@abstractmethod
def takeout_id(self):
"""
Returns an ID of the takeout process initialized for this session,
or ``None`` if there's no were any unfinished takeout requests.
"""
raise NotImplementedError
@takeout_id.setter
@abstractmethod
def takeout_id(self, value):
"""
Sets the ID of the unfinished takeout process for this session.
"""
raise NotImplementedError
@abstractmethod
def get_update_state(self, entity_id):
"""

View File

@@ -32,6 +32,7 @@ class MemorySession(Session):
self._server_address = None
self._port = None
self._auth_key = None
self._takeout_id = None
self._files = {}
self._entities = set()
@@ -62,6 +63,14 @@ class MemorySession(Session):
def auth_key(self, value):
self._auth_key = value
@property
def takeout_id(self):
return self._takeout_id
@takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None)

View File

@@ -18,7 +18,7 @@ except ImportError:
sqlite3 = None
EXTENSION = '.session'
CURRENT_VERSION = 4 # database version
CURRENT_VERSION = 5 # database version
class SQLiteSession(MemorySession):
@@ -65,7 +65,8 @@ class SQLiteSession(MemorySession):
c.execute('select * from sessions')
tuple_ = c.fetchone()
if tuple_:
self._dc_id, self._server_address, self._port, key, = tuple_
self._dc_id, self._server_address, self._port, key, \
self._takeout_id = tuple_
self._auth_key = AuthKey(data=key)
c.close()
@@ -79,7 +80,8 @@ class SQLiteSession(MemorySession):
dc_id integer primary key,
server_address text,
port integer,
auth_key blob
auth_key blob,
takeout_id integer
)"""
,
"""entities (
@@ -172,6 +174,9 @@ class SQLiteSession(MemorySession):
date integer,
seq integer
)""")
if old == 4:
old += 1
c.execute("alter table sessions add column takeout_id integer")
c.close()
@staticmethod
@@ -197,6 +202,11 @@ class SQLiteSession(MemorySession):
self._auth_key = value
self._update_session_table()
@MemorySession.takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
self._update_session_table()
def _update_session_table(self):
c = self._cursor()
# While we can save multiple rows into the sessions table
@@ -205,11 +215,12 @@ class SQLiteSession(MemorySession):
# some more work before being able to save auth_key's for
# multiple DCs. Probably done differently.
c.execute('delete from sessions')
c.execute('insert or replace into sessions values (?,?,?,?)', (
c.execute('insert or replace into sessions values (?,?,?,?,?)', (
self._dc_id,
self._server_address,
self._port,
self._auth_key.key if self._auth_key else b''
self._auth_key.key if self._auth_key else b'',
self._takeout_id
))
c.close()

View File

@@ -5,12 +5,16 @@ import struct
from .memory import MemorySession
from ..crypto import AuthKey
_STRUCT_PREFORMAT = '>B{}sH256s'
CURRENT_VERSION = '1'
class StringSession(MemorySession):
"""
This minimal session file can be easily saved and loaded as a string.
This session file can be easily saved and loaded as a string. According
to the initial design, it contains only the data that is necessary for
successful connection and authentication, so takeout ID is not stored.
It is thought to be used where you don't want to create any on-disk
files but would still like to be able to save and load existing sessions
@@ -33,7 +37,7 @@ class StringSession(MemorySession):
string = string[1:]
ip_len = 4 if len(string) == 352 else 16
self._dc_id, ip, self._port, key = struct.unpack(
'>B{}sH256s'.format(ip_len), StringSession.decode(string))
_STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string))
self._server_address = ipaddress.ip_address(ip).compressed
if any(key):
@@ -45,7 +49,7 @@ class StringSession(MemorySession):
ip = ipaddress.ip_address(self._server_address).packed
return CURRENT_VERSION + StringSession.encode(struct.pack(
'>B{}sH256s'.format(len(ip)),
_STRUCT_PREFORMAT.format(len(ip)),
self._dc_id,
ip,
self._port,