Improve TakeoutClient proxy and takeout functionality (#1106)

This commit is contained in:
Dmitry D. Chernov 2019-02-10 20:10:41 +10:00 committed by Lonami
parent 274fa72a8c
commit 9a98d41a2c
6 changed files with 153 additions and 51 deletions

View File

@ -8,45 +8,65 @@ from ..tl import functions, TLRequest
class _TakeoutClient: class _TakeoutClient:
""" """
Proxy object over the client. `c` is the client, `k` it's class, Proxy object over the client.
`r` is the takeout request, and `t` is the takeout ID.
""" """
def __init__(self, client, request): __PROXY_INTERFACE = ('__enter__', '__exit__', '__aenter__', '__aexit__')
# We're a proxy object with __getattribute__overrode so we
# need to set attributes through the super class `object`. def __init__(self, finalize, client, request):
super().__setattr__('c', client) # We use the name mangling for attributes to make them inaccessible
super().__setattr__('k', client.__class__) # from within the shadowed client object and to distinguish them from
super().__setattr__('r', request) # its own attributes where needed.
super().__setattr__('t', None) self.__finalize = finalize
self.__client = client
self.__request = request
self.__success = None
@property
def success(self):
return self.__success
@success.setter
def success(self, value):
self.__success = value
def __enter__(self): def __enter__(self):
# We also get self attributes through super() if self.__client.loop.is_running():
if super().__getattribute__('c').loop.is_running():
raise RuntimeError( raise RuntimeError(
'You must use "async with" if the event loop ' 'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")' 'is running (i.e. you are inside an "async def")'
) )
return super().__getattribute__( return self.__client.loop.run_until_complete(self.__aenter__())
'c').loop.run_until_complete(self.__aenter__())
async def __aenter__(self): async def __aenter__(self):
# Enter/Exit behaviour is "overrode", we don't want to call start # Enter/Exit behaviour is "overrode", we don't want to call start.
cl = super().__getattribute__('c') client = self.__client
super().__setattr__('t', (await cl(super().__getattribute__('r'))).id) if client.session.takeout_id is None:
client.session.takeout_id = (await client(self.__request)).id
elif self.__request is not None:
raise ValueError("Can't send a takeout request while another "
"takeout for the current session still not been finished yet.")
return self return self
def __exit__(self, *args): def __exit__(self, *args):
return super().__getattribute__( return self.__client.loop.run_until_complete(self.__aexit__(*args))
'c').loop.run_until_complete(self.__aexit__(*args))
async def __aexit__(self, *args): async def __aexit__(self, exc_type, exc_value, traceback):
super().__setattr__('t', None) if self.__success is None and self.__finalize:
self.__success = exc_type is None
if self.__success is not None:
result = await self(functions.account.FinishTakeoutSessionRequest(
self.__success))
if not result:
raise ValueError("Failed to finish the takeout.")
self.session.takeout_id = None
async def __call__(self, request, ordered=False): async def __call__(self, request, ordered=False):
takeout_id = super().__getattribute__('t') takeout_id = self.__client.session.takeout_id
if takeout_id is None: if takeout_id is None:
raise ValueError('Cannot call takeout methods outside of "with"') raise ValueError('Takeout mode has not been initialized '
'(are you calling outside of "with"?)')
single = not utils.is_list_like(request) single = not utils.is_list_like(request)
requests = ((request,) if single else request) requests = ((request,) if single else request)
@ -57,34 +77,43 @@ class _TakeoutClient:
await r.resolve(self, utils) await r.resolve(self, utils)
wrapped.append(functions.InvokeWithTakeoutRequest(takeout_id, r)) wrapped.append(functions.InvokeWithTakeoutRequest(takeout_id, r))
return await super().__getattribute__('c')( return await self.__client(
wrapped[0] if single else wrapped, ordered=ordered) wrapped[0] if single else wrapped, ordered=ordered)
def __getattribute__(self, name): def __getattribute__(self, name):
if name.startswith('__'): # We access class via type() because __class__ will recurse infinitely.
# We want to override special method names # Also note that since we've name-mangled our own class attributes,
if name == '__class__': # they'll be passed to __getattribute__() as already decorated. For
# See https://github.com/LonamiWebs/Telethon/issues/1103. # example, 'self.__client' will be passed as '_TakeoutClient__client'.
name = 'k' # https://docs.python.org/3/tutorial/classes.html#private-variables
if name.startswith('__') and name not in type(self).__PROXY_INTERFACE:
raise AttributeError # force call of __getattr__
# Try to access attribute in the proxy object and check for the same
# attribute in the shadowed object (through our __getattr__) if failed.
return super().__getattribute__(name) return super().__getattribute__(name)
value = getattr(super().__getattribute__('c'), name) def __getattr__(self, name):
value = getattr(self.__client, name)
if inspect.ismethod(value): if inspect.ismethod(value):
# Emulate bound methods behaviour by partially applying # Emulate bound methods behavior by partially applying our proxy
# our proxy class as the self parameter instead of the client # class as the self parameter instead of the client.
return functools.partial( return functools.partial(
getattr(super().__getattribute__('k'), name), self) getattr(self.__client.__class__, name), self)
else:
return value return value
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(super().__getattribute__('c'), name, value) if name.startswith('_{}__'.format(type(self).__name__.lstrip('_'))):
# This is our own name-mangled attribute, keep calm.
return super().__setattr__(name, value)
return setattr(self.__client, name, value)
class AccountMethods(UserMethods): class AccountMethods(UserMethods):
def takeout( def takeout(
self, contacts=None, users=None, chats=None, megagroups=None, self, finalize=True, *, contacts=None, users=None, chats=None,
channels=None, files=None, max_file_size=None): megagroups=None, channels=None, files=None, max_file_size=None):
""" """
Creates a proxy object over the current :ref:`TelegramClient` through Creates a proxy object over the current :ref:`TelegramClient` through
which making requests will use :tl:`InvokeWithTakeoutRequest` to wrap which making requests will use :tl:`InvokeWithTakeoutRequest` to wrap
@ -105,14 +134,24 @@ class AccountMethods(UserMethods):
to adjust the `wait_time` of methods like `client.iter_messages to adjust the `wait_time` of methods like `client.iter_messages
<telethon.client.messages.MessageMethods.iter_messages>`. <telethon.client.messages.MessageMethods.iter_messages>`.
By default, all parameters are ``False``, and you need to enable By default, all parameters are ``None``, and you need to enable those
those you plan to use by setting them to ``True``. you plan to use by setting them to either ``True`` or ``False``.
You should ``except errors.TakeoutInitDelayError as e``, since this You should ``except errors.TakeoutInitDelayError as e``, since this
exception will raise depending on the condition of the session. You exception will raise depending on the condition of the session. You
can then access ``e.seconds`` to know how long you should wait for can then access ``e.seconds`` to know how long you should wait for
before calling the method again. before calling the method again.
There's also a `success` property available in the takeout proxy
object, so from the `with` body you can set the boolean result that
will be sent back to Telegram. But if it's left ``None`` as by
default, then the action is based on the `finalize` parameter. If
it's ``True`` then the takeout will be finished, and if no exception
occurred during it, then ``True`` will be considered as a result.
Otherwise, the takeout will not be finished and its ID will be
preserved for future usage as `client.session.takeout_id
<telethon.sessions.abstract.Session.takeout_id>`.
Args: Args:
contacts (`bool`): contacts (`bool`):
Set to ``True`` if you plan on downloading contacts. Set to ``True`` if you plan on downloading contacts.
@ -141,7 +180,7 @@ class AccountMethods(UserMethods):
The maximum file size, in bytes, that you plan The maximum file size, in bytes, that you plan
to download for each message with media. to download for each message with media.
""" """
return _TakeoutClient(self, functions.account.InitTakeoutSessionRequest( request_kwargs = dict(
contacts=contacts, contacts=contacts,
message_users=users, message_users=users,
message_chats=chats, message_chats=chats,
@ -149,4 +188,27 @@ class AccountMethods(UserMethods):
message_channels=channels, message_channels=channels,
files=files, files=files,
file_max_size=max_file_size file_max_size=max_file_size
)) )
arg_specified = (arg is not None for arg in request_kwargs.values())
if self.session.takeout_id is None or any(arg_specified):
request = functions.account.InitTakeoutSessionRequest(
**request_kwargs)
else:
request = None
return _TakeoutClient(finalize, self, request)
async def end_takeout(self, success):
"""
Finishes a takeout, with specified result sent back to Telegram.
Returns:
``True`` if the operation was successful, ``False`` otherwise.
"""
try:
async with _TakeoutClient(True, self, None) as takeout:
takeout.success = success
except ValueError:
return False
return True

View File

@ -262,7 +262,6 @@ class TelegramBaseClient(abc.ABC):
) )
) )
self._connection = connection
self._sender = MTProtoSender( self._sender = MTProtoSender(
self.session.auth_key, self._loop, self.session.auth_key, self._loop,
loggers=self._log, loggers=self._log,

View File

@ -53,6 +53,23 @@ class Session(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@property
@abstractmethod
def takeout_id(self):
"""
Returns an ID of the takeout process initialized for this session,
or ``None`` if there's no were any unfinished takeout requests.
"""
raise NotImplementedError
@takeout_id.setter
@abstractmethod
def takeout_id(self, value):
"""
Sets the ID of the unfinished takeout process for this session.
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def get_update_state(self, entity_id): def get_update_state(self, entity_id):
""" """

View File

@ -32,6 +32,7 @@ class MemorySession(Session):
self._server_address = None self._server_address = None
self._port = None self._port = None
self._auth_key = None self._auth_key = None
self._takeout_id = None
self._files = {} self._files = {}
self._entities = set() self._entities = set()
@ -62,6 +63,14 @@ class MemorySession(Session):
def auth_key(self, value): def auth_key(self, value):
self._auth_key = value self._auth_key = value
@property
def takeout_id(self):
return self._takeout_id
@takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
def get_update_state(self, entity_id): def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None) return self._update_states.get(entity_id, None)

View File

@ -18,7 +18,7 @@ except ImportError:
sqlite3 = None sqlite3 = None
EXTENSION = '.session' EXTENSION = '.session'
CURRENT_VERSION = 4 # database version CURRENT_VERSION = 5 # database version
class SQLiteSession(MemorySession): class SQLiteSession(MemorySession):
@ -65,7 +65,8 @@ class SQLiteSession(MemorySession):
c.execute('select * from sessions') c.execute('select * from sessions')
tuple_ = c.fetchone() tuple_ = c.fetchone()
if tuple_: if tuple_:
self._dc_id, self._server_address, self._port, key, = tuple_ self._dc_id, self._server_address, self._port, key, \
self._takeout_id = tuple_
self._auth_key = AuthKey(data=key) self._auth_key = AuthKey(data=key)
c.close() c.close()
@ -79,7 +80,8 @@ class SQLiteSession(MemorySession):
dc_id integer primary key, dc_id integer primary key,
server_address text, server_address text,
port integer, port integer,
auth_key blob auth_key blob,
takeout_id integer
)""" )"""
, ,
"""entities ( """entities (
@ -172,6 +174,9 @@ class SQLiteSession(MemorySession):
date integer, date integer,
seq integer seq integer
)""") )""")
if old == 4:
old += 1
c.execute("alter table sessions add column takeout_id integer")
c.close() c.close()
@staticmethod @staticmethod
@ -197,6 +202,11 @@ class SQLiteSession(MemorySession):
self._auth_key = value self._auth_key = value
self._update_session_table() self._update_session_table()
@MemorySession.takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
self._update_session_table()
def _update_session_table(self): def _update_session_table(self):
c = self._cursor() c = self._cursor()
# While we can save multiple rows into the sessions table # While we can save multiple rows into the sessions table
@ -205,11 +215,12 @@ class SQLiteSession(MemorySession):
# some more work before being able to save auth_key's for # some more work before being able to save auth_key's for
# multiple DCs. Probably done differently. # multiple DCs. Probably done differently.
c.execute('delete from sessions') c.execute('delete from sessions')
c.execute('insert or replace into sessions values (?,?,?,?)', ( c.execute('insert or replace into sessions values (?,?,?,?,?)', (
self._dc_id, self._dc_id,
self._server_address, self._server_address,
self._port, self._port,
self._auth_key.key if self._auth_key else b'' self._auth_key.key if self._auth_key else b'',
self._takeout_id
)) ))
c.close() c.close()

View File

@ -5,12 +5,16 @@ import struct
from .memory import MemorySession from .memory import MemorySession
from ..crypto import AuthKey from ..crypto import AuthKey
_STRUCT_PREFORMAT = '>B{}sH256s'
CURRENT_VERSION = '1' CURRENT_VERSION = '1'
class StringSession(MemorySession): class StringSession(MemorySession):
""" """
This minimal session file can be easily saved and loaded as a string. This session file can be easily saved and loaded as a string. According
to the initial design, it contains only the data that is necessary for
successful connection and authentication, so takeout ID is not stored.
It is thought to be used where you don't want to create any on-disk It is thought to be used where you don't want to create any on-disk
files but would still like to be able to save and load existing sessions files but would still like to be able to save and load existing sessions
@ -33,7 +37,7 @@ class StringSession(MemorySession):
string = string[1:] string = string[1:]
ip_len = 4 if len(string) == 352 else 16 ip_len = 4 if len(string) == 352 else 16
self._dc_id, ip, self._port, key = struct.unpack( self._dc_id, ip, self._port, key = struct.unpack(
'>B{}sH256s'.format(ip_len), StringSession.decode(string)) _STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string))
self._server_address = ipaddress.ip_address(ip).compressed self._server_address = ipaddress.ip_address(ip).compressed
if any(key): if any(key):
@ -45,7 +49,7 @@ class StringSession(MemorySession):
ip = ipaddress.ip_address(self._server_address).packed ip = ipaddress.ip_address(self._server_address).packed
return CURRENT_VERSION + StringSession.encode(struct.pack( return CURRENT_VERSION + StringSession.encode(struct.pack(
'>B{}sH256s'.format(len(ip)), _STRUCT_PREFORMAT.format(len(ip)),
self._dc_id, self._dc_id,
ip, ip,
self._port, self._port,