diff --git a/telethon/client/auth.py b/telethon/client/auth.py index 9665262b..dff4e1e3 100644 --- a/telethon/client/auth.py +++ b/telethon/client/auth.py @@ -597,7 +597,7 @@ class AuthMethods: self._state_cache.reset() await self.disconnect() - self.session.delete() + await self.session.delete() return True async def edit_2fa( diff --git a/telethon/client/downloads.py b/telethon/client/downloads.py index 6d7a8d65..bace4211 100644 --- a/telethon/client/downloads.py +++ b/telethon/client/downloads.py @@ -55,7 +55,7 @@ class _DirectDownloadIter(RequestIter): if option.ip_address == self.client.session.server_address: self.client.session.set_dc( option.id, option.ip_address, option.port) - self.client.session.save() + await self.client.session.save() break # TODO Figure out why the session may have the wrong DC ID @@ -402,7 +402,7 @@ class DownloadMethods: if isinstance(message.action, types.MessageActionChatEditPhoto): media = media.photo - + if isinstance(media, types.MessageMediaWebPage): if isinstance(media.webpage, types.WebPage): media = media.webpage.document or media.webpage.photo diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 01011b58..cda410d1 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -1019,7 +1019,7 @@ class MessageMethods: async def edit_message( self: 'TelegramClient', entity: 'typing.Union[hints.EntityLike, types.Message]', - message: 'hints.MessageLike' = None, + message: 'hints.MessageIDLike' = None, text: str = None, *, parse_mode: str = (), diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 494daf9c..46ee312f 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -412,10 +412,7 @@ class TelegramBaseClient(abc.ABC): self._authorized = None # None = unknown, False = no, True = yes - # Update state (for catching up after a disconnection) - # TODO Get state from channels too - self._state_cache = StateCache( - self.session.get_update_state(0), self._log) + self._state_cache = StateCache(None, self._log) # Some further state for subclasses self._event_builders = [] @@ -522,6 +519,11 @@ class TelegramBaseClient(abc.ABC): except OSError: print('Failed to connect') """ + # Update state (for catching up after a disconnection) + # TODO Get state from channels too + self._state_cache = StateCache( + await self.session.get_update_state(0), self._log) + if not await self._sender.connect(self._connection( self.session.server_address, self.session.port, @@ -534,7 +536,7 @@ class TelegramBaseClient(abc.ABC): return self.session.auth_key = self._sender.auth_key - self.session.save() + await self.session.save() self._init_request.query = functions.help.GetConfigRequest() @@ -644,7 +646,7 @@ class TelegramBaseClient(abc.ABC): pts, date = self._state_cache[None] if pts and date: - self.session.set_update_state(0, types.updates.State( + await self.session.set_update_state(0, types.updates.State( pts=pts, qts=0, date=date, @@ -652,7 +654,7 @@ class TelegramBaseClient(abc.ABC): unread_count=0 )) - self.session.close() + await self.session.close() async def _disconnect(self: 'TelegramClient'): """ @@ -677,17 +679,17 @@ class TelegramBaseClient(abc.ABC): # so it's not valid anymore. Set to None to force recreating it. self._sender.auth_key.key = None self.session.auth_key = None - self.session.save() + await self.session.save() await self._disconnect() return await self.connect() - def _auth_key_callback(self: 'TelegramClient', auth_key): + async def _auth_key_callback(self: 'TelegramClient', auth_key): """ Callback from the sender whenever it needed to generate a new authorization key. This means we are not authorized. """ self.session.auth_key = auth_key - self.session.save() + await self.session.save() # endregion @@ -812,7 +814,7 @@ class TelegramBaseClient(abc.ABC): if not session: dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) session = self.session.clone() - await session.set_dc(dc.id, dc.ip_address, dc.port) + session.set_dc(dc.id, dc.ip_address, dc.port) self._exported_sessions[cdn_redirect.dc_id] = session self._log[__name__].info('Creating new CDN client') diff --git a/telethon/client/updates.py b/telethon/client/updates.py index bcc983f3..65cff607 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -255,7 +255,7 @@ class UpdateMethods: state = d.intermediate_state pts, date = state.pts, state.date - self._handle_update(types.Updates( + await self._handle_update(types.Updates( users=d.users, chats=d.chats, date=state.date, @@ -300,8 +300,8 @@ class UpdateMethods: # It is important to not make _handle_update async because we rely on # the order that the updates arrive in to update the pts and date to # be always-increasing. There is also no need to make this async. - def _handle_update(self: 'TelegramClient', update): - self.session.process_entities(update) + async def _handle_update(self: 'TelegramClient', update): + await self.session.process_entities(update) self._entity_cache.add(update) if isinstance(update, (types.Updates, types.UpdatesCombined)): @@ -372,7 +372,7 @@ class UpdateMethods: # inserted because this is a rather expensive operation # (default's sqlite3 takes ~0.1s to commit changes). Do # it every minute instead. No-op if there's nothing new. - self.session.save() + await self.session.save() # We need to send some content-related request at least hourly # for Telegram to keep delivering updates, otherwise they will diff --git a/telethon/client/users.py b/telethon/client/users.py index 22db969e..615d97cf 100644 --- a/telethon/client/users.py +++ b/telethon/client/users.py @@ -71,7 +71,7 @@ class UserMethods: exceptions.append(e) results.append(None) continue - self.session.process_entities(result) + await self.session.process_entities(result) self._entity_cache.add(result) exceptions.append(None) results.append(result) @@ -82,7 +82,7 @@ class UserMethods: return results else: result = await future - self.session.process_entities(result) + await self.session.process_entities(result) self._entity_cache.add(result) return result except (errors.ServerError, errors.RpcCallFailError, @@ -427,7 +427,7 @@ class UserMethods: # No InputPeer, cached peer, or known string. Fetch from disk cache try: - return self.session.get_input_entity(peer) + return await self.session.get_input_entity(peer) except ValueError: pass @@ -567,7 +567,7 @@ class UserMethods: try: # Nobody with this username, maybe it's an exact name/title return await self.get_entity( - self.session.get_input_entity(string)) + await self.session.get_input_entity(string)) except ValueError: pass diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index ca592ac0..05fd39e9 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -295,7 +295,7 @@ class MTProtoSender: # notify whenever we change it. This is crucial when we # switch to different data centers. if self._auth_key_callback: - self._auth_key_callback(self.auth_key) + await self._auth_key_callback(self.auth_key) self._log.debug('auth_key generation success!') return True @@ -380,7 +380,7 @@ class MTProtoSender: self._log.info('Broken authorization key; resetting') self.auth_key.key = None if self._auth_key_callback: - self._auth_key_callback(None) + await self._auth_key_callback(None) ok = False break @@ -524,7 +524,7 @@ class MTProtoSender: self._log.info('Broken authorization key; resetting') self.auth_key.key = None if self._auth_key_callback: - self._auth_key_callback(None) + await self._auth_key_callback(None) await self._disconnect(error=e) else: @@ -653,7 +653,7 @@ class MTProtoSender: self._log.debug('Handling update %s', message.obj.__class__.__name__) if self._update_callback: - self._update_callback(message.obj) + await self._update_callback(message.obj) async def _handle_pong(self, message): """ diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py index 5fda1c18..94afac86 100644 --- a/telethon/sessions/abstract.py +++ b/telethon/sessions/abstract.py @@ -79,7 +79,7 @@ class Session(ABC): raise NotImplementedError @abstractmethod - def get_update_state(self, entity_id): + async def get_update_state(self, entity_id): """ Returns the ``UpdateState`` associated with the given `entity_id`. If the `entity_id` is 0, it should return the ``UpdateState`` for @@ -89,7 +89,7 @@ class Session(ABC): raise NotImplementedError @abstractmethod - def set_update_state(self, entity_id, state): + async def set_update_state(self, entity_id, state): """ Sets the given ``UpdateState`` for the specified `entity_id`, which should be 0 if the ``UpdateState`` is the "general" state (and not @@ -98,14 +98,14 @@ class Session(ABC): raise NotImplementedError @abstractmethod - def close(self): + async def close(self): """ Called on client disconnection. Should be used to free any used resources. Can be left empty if none. """ @abstractmethod - def save(self): + async def save(self): """ Called whenever important properties change. It should make persist the relevant session information to disk. @@ -113,22 +113,15 @@ class Session(ABC): raise NotImplementedError @abstractmethod - def delete(self): + async def delete(self): """ Called upon client.log_out(). Should delete the stored information from disk since it's not valid anymore. """ raise NotImplementedError - @classmethod - def list_sessions(cls): - """ - Lists available sessions. Not used by the library itself. - """ - return [] - @abstractmethod - def process_entities(self, tlo): + async def process_entities(self, tlo): """ Processes the input ``TLObject`` or ``list`` and saves whatever information is relevant (e.g., ID or access hash). @@ -136,7 +129,7 @@ class Session(ABC): raise NotImplementedError @abstractmethod - def get_input_entity(self, key): + async def get_input_entity(self, key): """ Turns the given key into an ``InputPeer`` (e.g. ``InputPeerUser``). The library uses this method whenever an ``InputPeer`` is needed @@ -144,24 +137,3 @@ class Session(ABC): to use a cached username to avoid extra RPC). """ raise NotImplementedError - - @abstractmethod - def cache_file(self, md5_digest, file_size, instance): - """ - Caches the given file information persistently, so that it - doesn't need to be re-uploaded in case the file is used again. - - The ``instance`` will be either an ``InputPhoto`` or ``InputDocument``, - both with an ``.id`` and ``.access_hash`` attributes. - """ - raise NotImplementedError - - @abstractmethod - def get_file(self, md5_digest, file_size, cls): - """ - Returns an instance of ``cls`` if the ``md5_digest`` and ``file_size`` - match an existing saved record. The class will either be an - ``InputPhoto`` or ``InputDocument``, both with two parameters - ``id`` and ``access_hash`` in that order. - """ - raise NotImplementedError diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py index 1b1a6bfb..8494545c 100644 --- a/telethon/sessions/memory.py +++ b/telethon/sessions/memory.py @@ -71,19 +71,19 @@ class MemorySession(Session): def takeout_id(self, value): self._takeout_id = value - def get_update_state(self, entity_id): + async def get_update_state(self, entity_id): return self._update_states.get(entity_id, None) - def set_update_state(self, entity_id, state): + async def set_update_state(self, entity_id, state): self._update_states[entity_id] = state - def close(self): + async def close(self): pass - def save(self): + async def save(self): pass - def delete(self): + async def delete(self): pass @staticmethod @@ -144,31 +144,31 @@ class MemorySession(Session): rows.append(row) return rows - def process_entities(self, tlo): + async def process_entities(self, tlo): self._entities |= set(self._entities_to_rows(tlo)) - def get_entity_rows_by_phone(self, phone): + async def get_entity_rows_by_phone(self, phone): try: return next((id, hash) for id, hash, _, found_phone, _ in self._entities if found_phone == phone) except StopIteration: pass - def get_entity_rows_by_username(self, username): + async def get_entity_rows_by_username(self, username): try: return next((id, hash) for id, hash, found_username, _, _ in self._entities if found_username == username) except StopIteration: pass - def get_entity_rows_by_name(self, name): + async def get_entity_rows_by_name(self, name): try: return next((id, hash) for id, hash, _, _, found_name in self._entities if found_name == name) except StopIteration: pass - def get_entity_rows_by_id(self, id, exact=True): + async def get_entity_rows_by_id(self, id, exact=True): try: if exact: return next((id, hash) for found_id, hash, _, _, _ @@ -184,7 +184,7 @@ class MemorySession(Session): except StopIteration: pass - def get_input_entity(self, key): + async def get_input_entity(self, key): try: if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd): # hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel')) @@ -204,21 +204,21 @@ class MemorySession(Session): if isinstance(key, str): phone = utils.parse_phone(key) if phone: - result = self.get_entity_rows_by_phone(phone) + result = await self.get_entity_rows_by_phone(phone) else: username, invite = utils.parse_username(key) if username and not invite: - result = self.get_entity_rows_by_username(username) + result = await self.get_entity_rows_by_username(username) else: tup = utils.resolve_invite_link(key)[1] if tup: - result = self.get_entity_rows_by_id(tup, exact=False) + result = await self.get_entity_rows_by_id(tup, exact=False) elif isinstance(key, int): - result = self.get_entity_rows_by_id(key, exact) + result = await self.get_entity_rows_by_id(key, exact) if not result and isinstance(key, str): - result = self.get_entity_rows_by_name(key) + result = await self.get_entity_rows_by_name(key) if result: entity_id, entity_hash = result # unpack resulting tuple @@ -233,14 +233,14 @@ class MemorySession(Session): else: raise ValueError('Could not find input entity with key ', key) - def cache_file(self, md5_digest, file_size, instance): + async def cache_file(self, md5_digest, file_size, instance): if not isinstance(instance, (InputDocument, InputPhoto)): raise TypeError('Cannot cache %s instance' % type(instance)) key = (md5_digest, file_size, _SentFileType.from_type(type(instance))) value = (instance.id, instance.access_hash) self._files[key] = value - def get_file(self, md5_digest, file_size, cls): + async def get_file(self, md5_digest, file_size, cls): key = (md5_digest, file_size, _SentFileType.from_type(cls)) try: return cls(*self._files[key])