From c864ef7e16565c08362e3d01b8a60ecb3bf6c31c Mon Sep 17 00:00:00 2001
From: Lonami Exo
Date: Thu, 24 Sep 2020 10:03:28 +0200
Subject: [PATCH] Refetch msg if fileref expires while downloading docs

Closes #1301.
---
Review note: in the new FilerefUpgradeNeededError handler, get_messages()
yields None when the message was deleted mid-download. The handler now
checks for that so the original RPC error is re-raised instead of an
AttributeError, and the unused "as e" binding is dropped. The comment
takes the place of a blank added line, so every hunk keeps its line count.

 telethon/client/downloads.py  | 95 ++++++++++++++++++++++++++++++++---
 telethon/tl/custom/message.py |  2 +
 2 files changed, 89 insertions(+), 8 deletions(-)

diff --git a/telethon/client/downloads.py b/telethon/client/downloads.py
index ba612197..d0fe2bac 100644
--- a/telethon/client/downloads.py
+++ b/telethon/client/downloads.py
@@ -26,7 +26,7 @@ MAX_CHUNK_SIZE = 512 * 1024

 class _DirectDownloadIter(RequestIter):
     async def _init(
-            self, file, dc_id, offset, stride, chunk_size, request_size, file_size
+            self, file, dc_id, offset, stride, chunk_size, request_size, file_size, msg_data
     ):
         self.request = functions.upload.GetFileRequest(
             file, offset=offset, limit=request_size)
@@ -35,6 +35,7 @@ class _DirectDownloadIter(RequestIter):
         self._stride = stride
         self._chunk_size = chunk_size
         self._last_part = None
+        self._msg_data = msg_data

         self._exported = dc_id and self.client.session.dc_id != dc_id
         if not self._exported:
@@ -80,6 +81,29 @@ class _DirectDownloadIter(RequestIter):
             self._exported = True
             return await self._request()

+        except errors.FilerefUpgradeNeededError:
+            # Only implemented for documents which are the ones that may take that long to download
+            if not self._msg_data \
+                    or not isinstance(self.request.location, types.InputDocumentFileLocation) \
+                    or self.request.location.thumb_size != '':
+                raise
+
+            self.client._log[__name__].info('File ref expired during download; refetching message')
+            chat, msg_id = self._msg_data
+            msg = await self.client.get_messages(chat, ids=msg_id)
+            # get_messages() yields None if the message was deleted in the meantime
+            if msg is None or not isinstance(msg.media, types.MessageMediaDocument):
+                raise
+
+            document = msg.media.document
+
+            # Message media may have been edited for something else
+            if document.id != self.request.location.id:
+                raise
+
+            self.request.location.file_reference = document.file_reference
+            return await self._request()
+
     async def close(self):
         if not self._sender:
             return
@@ -344,10 +368,16 @@ class DownloadMethods:

                 await client.download_media(message, progress_callback=callback)
         """
+        # Downloading large documents may be slow enough to require a new file reference
+        # to be obtained mid-download. Store (input chat, message id) so that the message
+        # can be re-fetched.
+        msg_data = None
+
         # TODO This won't work for messageService
         if isinstance(message, types.Message):
             date = message.date
             media = message.media
+            msg_data = (message.input_chat, message.id) if message.input_chat else None
         else:
             date = datetime.datetime.now()
             media = message
@@ -365,7 +395,7 @@ class DownloadMethods:
             )
         elif isinstance(media, (types.MessageMediaDocument, types.Document)):
             return await self._download_document(
-                media, file, date, thumb, progress_callback
+                media, file, date, thumb, progress_callback, msg_data
             )
         elif isinstance(media, types.MessageMediaContact) and thumb is None:
             return self._download_contact(
@@ -439,6 +469,29 @@ class DownloadMethods:
                 data = await client.download_file(input_file, bytes)
                 print(data[:16])
         """
+        return await self._download_file(
+            input_location,
+            file,
+            part_size_kb=part_size_kb,
+            file_size=file_size,
+            progress_callback=progress_callback,
+            dc_id=dc_id,
+            key=key,
+            iv=iv,
+        )
+
+    async def _download_file(
+            self: 'TelegramClient',
+            input_location: 'hints.FileLike',
+            file: 'hints.OutFileLike' = None,
+            *,
+            part_size_kb: float = None,
+            file_size: int = None,
+            progress_callback: 'hints.ProgressCallback' = None,
+            dc_id: int = None,
+            key: bytes = None,
+            iv: bytes = None,
+            msg_data: tuple = None) -> typing.Optional[bytes]:
         if not part_size_kb:
             if not file_size:
                 part_size_kb = 64  # Reasonable default
@@ -464,8 +517,8 @@ class DownloadMethods:
             f = file

         try:
-            async for chunk in self.iter_download(
-                    input_location, request_size=part_size, dc_id=dc_id):
+            async for chunk in self._iter_download(
+                    input_location, request_size=part_size, dc_id=dc_id, msg_data=msg_data):
                 if iv and key:
                     chunk = AES.decrypt_ige(chunk, key, iv)
                 r = f.write(chunk)
@@ -582,6 +635,30 @@ class DownloadMethods:
                 await stream.close()
                 assert len(header) == 32
         """
+        return self._iter_download(
+            file,
+            offset=offset,
+            stride=stride,
+            limit=limit,
+            chunk_size=chunk_size,
+            request_size=request_size,
+            file_size=file_size,
+            dc_id=dc_id,
+        )
+
+    def _iter_download(
+            self: 'TelegramClient',
+            file: 'hints.FileLike',
+            *,
+            offset: int = 0,
+            stride: int = None,
+            limit: int = None,
+            chunk_size: int = None,
+            request_size: int = MAX_CHUNK_SIZE,
+            file_size: int = None,
+            dc_id: int = None,
+            msg_data: tuple = None
+    ):
         info = utils._get_file_info(file)
         if info.dc_id is not None:
             dc_id = info.dc_id
@@ -628,7 +705,8 @@ class DownloadMethods:
             stride=stride,
             chunk_size=chunk_size,
             request_size=request_size,
-            file_size=file_size
+            file_size=file_size,
+            msg_data=msg_data,
         )

     # endregion
@@ -748,7 +826,7 @@ class DownloadMethods:
         return kind, possible_names

     async def _download_document(
-            self, document, file, date, thumb, progress_callback):
+            self, document, file, date, thumb, progress_callback, msg_data):
         """Specialized version of .download_media() for documents."""
         if isinstance(document, types.MessageMediaDocument):
             document = document.document
@@ -768,7 +846,7 @@ class DownloadMethods:
             if isinstance(size, (types.PhotoCachedSize, types.PhotoStrippedSize)):
                 return self._download_cached_photo_size(size, file)

-        result = await self.download_file(
+        result = await self._download_file(
             types.InputDocumentFileLocation(
                 id=document.id,
                 access_hash=document.access_hash,
@@ -777,7 +855,8 @@ class DownloadMethods:
             ),
             file,
             file_size=size.size if size else document.size,
-            progress_callback=progress_callback
+            progress_callback=progress_callback,
+            msg_data=msg_data,
         )

         return result if file is bytes else file
diff --git a/telethon/tl/custom/message.py b/telethon/tl/custom/message.py
index 5f286a40..820c1fbd 100644
--- a/telethon/tl/custom/message.py
+++ b/telethon/tl/custom/message.py
@@ -761,6 +761,8 @@ class Message(ChatGetter, SenderGetter, TLObject, abc.ABC):
         with the ``message`` already set.
         """
         if self._client:
+            # Passing the entire message is important, in case it has to be
+            # refetched for a fresh file reference.
             return await self._client.download_media(self, *args, **kwargs)

     async def click(self, i=None, j=None,