mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-06-17 10:36:37 +00:00
Improve TakeoutClient proxy and takeout functionality (#1106)
This commit is contained in:
parent
274fa72a8c
commit
9a98d41a2c
@ -8,45 +8,65 @@ from ..tl import functions, TLRequest
|
|||||||
|
|
||||||
class _TakeoutClient:
|
class _TakeoutClient:
|
||||||
"""
|
"""
|
||||||
Proxy object over the client. `c` is the client, `k` it's class,
|
Proxy object over the client.
|
||||||
`r` is the takeout request, and `t` is the takeout ID.
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, client, request):
|
__PROXY_INTERFACE = ('__enter__', '__exit__', '__aenter__', '__aexit__')
|
||||||
# We're a proxy object with __getattribute__overrode so we
|
|
||||||
# need to set attributes through the super class `object`.
|
def __init__(self, finalize, client, request):
|
||||||
super().__setattr__('c', client)
|
# We use the name mangling for attributes to make them inaccessible
|
||||||
super().__setattr__('k', client.__class__)
|
# from within the shadowed client object and to distinguish them from
|
||||||
super().__setattr__('r', request)
|
# its own attributes where needed.
|
||||||
super().__setattr__('t', None)
|
self.__finalize = finalize
|
||||||
|
self.__client = client
|
||||||
|
self.__request = request
|
||||||
|
self.__success = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def success(self):
|
||||||
|
return self.__success
|
||||||
|
|
||||||
|
@success.setter
|
||||||
|
def success(self, value):
|
||||||
|
self.__success = value
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
# We also get self attributes through super()
|
if self.__client.loop.is_running():
|
||||||
if super().__getattribute__('c').loop.is_running():
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'You must use "async with" if the event loop '
|
'You must use "async with" if the event loop '
|
||||||
'is running (i.e. you are inside an "async def")'
|
'is running (i.e. you are inside an "async def")'
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().__getattribute__(
|
return self.__client.loop.run_until_complete(self.__aenter__())
|
||||||
'c').loop.run_until_complete(self.__aenter__())
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
# Enter/Exit behaviour is "overrode", we don't want to call start
|
# Enter/Exit behaviour is "overrode", we don't want to call start.
|
||||||
cl = super().__getattribute__('c')
|
client = self.__client
|
||||||
super().__setattr__('t', (await cl(super().__getattribute__('r'))).id)
|
if client.session.takeout_id is None:
|
||||||
|
client.session.takeout_id = (await client(self.__request)).id
|
||||||
|
elif self.__request is not None:
|
||||||
|
raise ValueError("Can't send a takeout request while another "
|
||||||
|
"takeout for the current session still not been finished yet.")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
return super().__getattribute__(
|
return self.__client.loop.run_until_complete(self.__aexit__(*args))
|
||||||
'c').loop.run_until_complete(self.__aexit__(*args))
|
|
||||||
|
|
||||||
async def __aexit__(self, *args):
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
super().__setattr__('t', None)
|
if self.__success is None and self.__finalize:
|
||||||
|
self.__success = exc_type is None
|
||||||
|
|
||||||
|
if self.__success is not None:
|
||||||
|
result = await self(functions.account.FinishTakeoutSessionRequest(
|
||||||
|
self.__success))
|
||||||
|
if not result:
|
||||||
|
raise ValueError("Failed to finish the takeout.")
|
||||||
|
self.session.takeout_id = None
|
||||||
|
|
||||||
async def __call__(self, request, ordered=False):
|
async def __call__(self, request, ordered=False):
|
||||||
takeout_id = super().__getattribute__('t')
|
takeout_id = self.__client.session.takeout_id
|
||||||
if takeout_id is None:
|
if takeout_id is None:
|
||||||
raise ValueError('Cannot call takeout methods outside of "with"')
|
raise ValueError('Takeout mode has not been initialized '
|
||||||
|
'(are you calling outside of "with"?)')
|
||||||
|
|
||||||
single = not utils.is_list_like(request)
|
single = not utils.is_list_like(request)
|
||||||
requests = ((request,) if single else request)
|
requests = ((request,) if single else request)
|
||||||
@ -57,34 +77,43 @@ class _TakeoutClient:
|
|||||||
await r.resolve(self, utils)
|
await r.resolve(self, utils)
|
||||||
wrapped.append(functions.InvokeWithTakeoutRequest(takeout_id, r))
|
wrapped.append(functions.InvokeWithTakeoutRequest(takeout_id, r))
|
||||||
|
|
||||||
return await super().__getattribute__('c')(
|
return await self.__client(
|
||||||
wrapped[0] if single else wrapped, ordered=ordered)
|
wrapped[0] if single else wrapped, ordered=ordered)
|
||||||
|
|
||||||
def __getattribute__(self, name):
|
def __getattribute__(self, name):
|
||||||
if name.startswith('__'):
|
# We access class via type() because __class__ will recurse infinitely.
|
||||||
# We want to override special method names
|
# Also note that since we've name-mangled our own class attributes,
|
||||||
if name == '__class__':
|
# they'll be passed to __getattribute__() as already decorated. For
|
||||||
# See https://github.com/LonamiWebs/Telethon/issues/1103.
|
# example, 'self.__client' will be passed as '_TakeoutClient__client'.
|
||||||
name = 'k'
|
# https://docs.python.org/3/tutorial/classes.html#private-variables
|
||||||
return super().__getattribute__(name)
|
if name.startswith('__') and name not in type(self).__PROXY_INTERFACE:
|
||||||
|
raise AttributeError # force call of __getattr__
|
||||||
|
|
||||||
value = getattr(super().__getattribute__('c'), name)
|
# Try to access attribute in the proxy object and check for the same
|
||||||
|
# attribute in the shadowed object (through our __getattr__) if failed.
|
||||||
|
return super().__getattribute__(name)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
value = getattr(self.__client, name)
|
||||||
if inspect.ismethod(value):
|
if inspect.ismethod(value):
|
||||||
# Emulate bound methods behaviour by partially applying
|
# Emulate bound methods behavior by partially applying our proxy
|
||||||
# our proxy class as the self parameter instead of the client
|
# class as the self parameter instead of the client.
|
||||||
return functools.partial(
|
return functools.partial(
|
||||||
getattr(super().__getattribute__('k'), name), self)
|
getattr(self.__client.__class__, name), self)
|
||||||
else:
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
setattr(super().__getattribute__('c'), name, value)
|
if name.startswith('_{}__'.format(type(self).__name__.lstrip('_'))):
|
||||||
|
# This is our own name-mangled attribute, keep calm.
|
||||||
|
return super().__setattr__(name, value)
|
||||||
|
return setattr(self.__client, name, value)
|
||||||
|
|
||||||
|
|
||||||
class AccountMethods(UserMethods):
|
class AccountMethods(UserMethods):
|
||||||
def takeout(
|
def takeout(
|
||||||
self, contacts=None, users=None, chats=None, megagroups=None,
|
self, finalize=True, *, contacts=None, users=None, chats=None,
|
||||||
channels=None, files=None, max_file_size=None):
|
megagroups=None, channels=None, files=None, max_file_size=None):
|
||||||
"""
|
"""
|
||||||
Creates a proxy object over the current :ref:`TelegramClient` through
|
Creates a proxy object over the current :ref:`TelegramClient` through
|
||||||
which making requests will use :tl:`InvokeWithTakeoutRequest` to wrap
|
which making requests will use :tl:`InvokeWithTakeoutRequest` to wrap
|
||||||
@ -105,14 +134,24 @@ class AccountMethods(UserMethods):
|
|||||||
to adjust the `wait_time` of methods like `client.iter_messages
|
to adjust the `wait_time` of methods like `client.iter_messages
|
||||||
<telethon.client.messages.MessageMethods.iter_messages>`.
|
<telethon.client.messages.MessageMethods.iter_messages>`.
|
||||||
|
|
||||||
By default, all parameters are ``False``, and you need to enable
|
By default, all parameters are ``None``, and you need to enable those
|
||||||
those you plan to use by setting them to ``True``.
|
you plan to use by setting them to either ``True`` or ``False``.
|
||||||
|
|
||||||
You should ``except errors.TakeoutInitDelayError as e``, since this
|
You should ``except errors.TakeoutInitDelayError as e``, since this
|
||||||
exception will raise depending on the condition of the session. You
|
exception will raise depending on the condition of the session. You
|
||||||
can then access ``e.seconds`` to know how long you should wait for
|
can then access ``e.seconds`` to know how long you should wait for
|
||||||
before calling the method again.
|
before calling the method again.
|
||||||
|
|
||||||
|
There's also a `success` property available in the takeout proxy
|
||||||
|
object, so from the `with` body you can set the boolean result that
|
||||||
|
will be sent back to Telegram. But if it's left ``None`` as by
|
||||||
|
default, then the action is based on the `finalize` parameter. If
|
||||||
|
it's ``True`` then the takeout will be finished, and if no exception
|
||||||
|
occurred during it, then ``True`` will be considered as a result.
|
||||||
|
Otherwise, the takeout will not be finished and its ID will be
|
||||||
|
preserved for future usage as `client.session.takeout_id
|
||||||
|
<telethon.sessions.abstract.Session.takeout_id>`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
contacts (`bool`):
|
contacts (`bool`):
|
||||||
Set to ``True`` if you plan on downloading contacts.
|
Set to ``True`` if you plan on downloading contacts.
|
||||||
@ -141,7 +180,7 @@ class AccountMethods(UserMethods):
|
|||||||
The maximum file size, in bytes, that you plan
|
The maximum file size, in bytes, that you plan
|
||||||
to download for each message with media.
|
to download for each message with media.
|
||||||
"""
|
"""
|
||||||
return _TakeoutClient(self, functions.account.InitTakeoutSessionRequest(
|
request_kwargs = dict(
|
||||||
contacts=contacts,
|
contacts=contacts,
|
||||||
message_users=users,
|
message_users=users,
|
||||||
message_chats=chats,
|
message_chats=chats,
|
||||||
@ -149,4 +188,27 @@ class AccountMethods(UserMethods):
|
|||||||
message_channels=channels,
|
message_channels=channels,
|
||||||
files=files,
|
files=files,
|
||||||
file_max_size=max_file_size
|
file_max_size=max_file_size
|
||||||
))
|
)
|
||||||
|
arg_specified = (arg is not None for arg in request_kwargs.values())
|
||||||
|
|
||||||
|
if self.session.takeout_id is None or any(arg_specified):
|
||||||
|
request = functions.account.InitTakeoutSessionRequest(
|
||||||
|
**request_kwargs)
|
||||||
|
else:
|
||||||
|
request = None
|
||||||
|
|
||||||
|
return _TakeoutClient(finalize, self, request)
|
||||||
|
|
||||||
|
async def end_takeout(self, success):
|
||||||
|
"""
|
||||||
|
Finishes a takeout, with specified result sent back to Telegram.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``True`` if the operation was successful, ``False`` otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with _TakeoutClient(True, self, None) as takeout:
|
||||||
|
takeout.success = success
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
@ -262,7 +262,6 @@ class TelegramBaseClient(abc.ABC):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._connection = connection
|
|
||||||
self._sender = MTProtoSender(
|
self._sender = MTProtoSender(
|
||||||
self.session.auth_key, self._loop,
|
self.session.auth_key, self._loop,
|
||||||
loggers=self._log,
|
loggers=self._log,
|
||||||
|
@ -53,6 +53,23 @@ class Session(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def takeout_id(self):
|
||||||
|
"""
|
||||||
|
Returns an ID of the takeout process initialized for this session,
|
||||||
|
or ``None`` if there's no were any unfinished takeout requests.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@takeout_id.setter
|
||||||
|
@abstractmethod
|
||||||
|
def takeout_id(self, value):
|
||||||
|
"""
|
||||||
|
Sets the ID of the unfinished takeout process for this session.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_update_state(self, entity_id):
|
def get_update_state(self, entity_id):
|
||||||
"""
|
"""
|
||||||
|
@ -32,6 +32,7 @@ class MemorySession(Session):
|
|||||||
self._server_address = None
|
self._server_address = None
|
||||||
self._port = None
|
self._port = None
|
||||||
self._auth_key = None
|
self._auth_key = None
|
||||||
|
self._takeout_id = None
|
||||||
|
|
||||||
self._files = {}
|
self._files = {}
|
||||||
self._entities = set()
|
self._entities = set()
|
||||||
@ -62,6 +63,14 @@ class MemorySession(Session):
|
|||||||
def auth_key(self, value):
|
def auth_key(self, value):
|
||||||
self._auth_key = value
|
self._auth_key = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def takeout_id(self):
|
||||||
|
return self._takeout_id
|
||||||
|
|
||||||
|
@takeout_id.setter
|
||||||
|
def takeout_id(self, value):
|
||||||
|
self._takeout_id = value
|
||||||
|
|
||||||
def get_update_state(self, entity_id):
|
def get_update_state(self, entity_id):
|
||||||
return self._update_states.get(entity_id, None)
|
return self._update_states.get(entity_id, None)
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ except ImportError:
|
|||||||
sqlite3 = None
|
sqlite3 = None
|
||||||
|
|
||||||
EXTENSION = '.session'
|
EXTENSION = '.session'
|
||||||
CURRENT_VERSION = 4 # database version
|
CURRENT_VERSION = 5 # database version
|
||||||
|
|
||||||
|
|
||||||
class SQLiteSession(MemorySession):
|
class SQLiteSession(MemorySession):
|
||||||
@ -65,7 +65,8 @@ class SQLiteSession(MemorySession):
|
|||||||
c.execute('select * from sessions')
|
c.execute('select * from sessions')
|
||||||
tuple_ = c.fetchone()
|
tuple_ = c.fetchone()
|
||||||
if tuple_:
|
if tuple_:
|
||||||
self._dc_id, self._server_address, self._port, key, = tuple_
|
self._dc_id, self._server_address, self._port, key, \
|
||||||
|
self._takeout_id = tuple_
|
||||||
self._auth_key = AuthKey(data=key)
|
self._auth_key = AuthKey(data=key)
|
||||||
|
|
||||||
c.close()
|
c.close()
|
||||||
@ -79,7 +80,8 @@ class SQLiteSession(MemorySession):
|
|||||||
dc_id integer primary key,
|
dc_id integer primary key,
|
||||||
server_address text,
|
server_address text,
|
||||||
port integer,
|
port integer,
|
||||||
auth_key blob
|
auth_key blob,
|
||||||
|
takeout_id integer
|
||||||
)"""
|
)"""
|
||||||
,
|
,
|
||||||
"""entities (
|
"""entities (
|
||||||
@ -172,6 +174,9 @@ class SQLiteSession(MemorySession):
|
|||||||
date integer,
|
date integer,
|
||||||
seq integer
|
seq integer
|
||||||
)""")
|
)""")
|
||||||
|
if old == 4:
|
||||||
|
old += 1
|
||||||
|
c.execute("alter table sessions add column takeout_id integer")
|
||||||
c.close()
|
c.close()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -197,6 +202,11 @@ class SQLiteSession(MemorySession):
|
|||||||
self._auth_key = value
|
self._auth_key = value
|
||||||
self._update_session_table()
|
self._update_session_table()
|
||||||
|
|
||||||
|
@MemorySession.takeout_id.setter
|
||||||
|
def takeout_id(self, value):
|
||||||
|
self._takeout_id = value
|
||||||
|
self._update_session_table()
|
||||||
|
|
||||||
def _update_session_table(self):
|
def _update_session_table(self):
|
||||||
c = self._cursor()
|
c = self._cursor()
|
||||||
# While we can save multiple rows into the sessions table
|
# While we can save multiple rows into the sessions table
|
||||||
@ -205,11 +215,12 @@ class SQLiteSession(MemorySession):
|
|||||||
# some more work before being able to save auth_key's for
|
# some more work before being able to save auth_key's for
|
||||||
# multiple DCs. Probably done differently.
|
# multiple DCs. Probably done differently.
|
||||||
c.execute('delete from sessions')
|
c.execute('delete from sessions')
|
||||||
c.execute('insert or replace into sessions values (?,?,?,?)', (
|
c.execute('insert or replace into sessions values (?,?,?,?,?)', (
|
||||||
self._dc_id,
|
self._dc_id,
|
||||||
self._server_address,
|
self._server_address,
|
||||||
self._port,
|
self._port,
|
||||||
self._auth_key.key if self._auth_key else b''
|
self._auth_key.key if self._auth_key else b'',
|
||||||
|
self._takeout_id
|
||||||
))
|
))
|
||||||
c.close()
|
c.close()
|
||||||
|
|
||||||
|
@ -5,12 +5,16 @@ import struct
|
|||||||
from .memory import MemorySession
|
from .memory import MemorySession
|
||||||
from ..crypto import AuthKey
|
from ..crypto import AuthKey
|
||||||
|
|
||||||
|
_STRUCT_PREFORMAT = '>B{}sH256s'
|
||||||
|
|
||||||
CURRENT_VERSION = '1'
|
CURRENT_VERSION = '1'
|
||||||
|
|
||||||
|
|
||||||
class StringSession(MemorySession):
|
class StringSession(MemorySession):
|
||||||
"""
|
"""
|
||||||
This minimal session file can be easily saved and loaded as a string.
|
This session file can be easily saved and loaded as a string. According
|
||||||
|
to the initial design, it contains only the data that is necessary for
|
||||||
|
successful connection and authentication, so takeout ID is not stored.
|
||||||
|
|
||||||
It is thought to be used where you don't want to create any on-disk
|
It is thought to be used where you don't want to create any on-disk
|
||||||
files but would still like to be able to save and load existing sessions
|
files but would still like to be able to save and load existing sessions
|
||||||
@ -33,7 +37,7 @@ class StringSession(MemorySession):
|
|||||||
string = string[1:]
|
string = string[1:]
|
||||||
ip_len = 4 if len(string) == 352 else 16
|
ip_len = 4 if len(string) == 352 else 16
|
||||||
self._dc_id, ip, self._port, key = struct.unpack(
|
self._dc_id, ip, self._port, key = struct.unpack(
|
||||||
'>B{}sH256s'.format(ip_len), StringSession.decode(string))
|
_STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string))
|
||||||
|
|
||||||
self._server_address = ipaddress.ip_address(ip).compressed
|
self._server_address = ipaddress.ip_address(ip).compressed
|
||||||
if any(key):
|
if any(key):
|
||||||
@ -45,7 +49,7 @@ class StringSession(MemorySession):
|
|||||||
|
|
||||||
ip = ipaddress.ip_address(self._server_address).packed
|
ip = ipaddress.ip_address(self._server_address).packed
|
||||||
return CURRENT_VERSION + StringSession.encode(struct.pack(
|
return CURRENT_VERSION + StringSession.encode(struct.pack(
|
||||||
'>B{}sH256s'.format(len(ip)),
|
_STRUCT_PREFORMAT.format(len(ip)),
|
||||||
self._dc_id,
|
self._dc_id,
|
||||||
ip,
|
ip,
|
||||||
self._port,
|
self._port,
|
||||||
|
Loading…
Reference in New Issue
Block a user