diff --git a/telethon/client/uploads.py b/telethon/client/uploads.py index c46c2899..a9d16758 100644 --- a/telethon/client/uploads.py +++ b/telethon/client/uploads.py @@ -540,83 +540,43 @@ class UploadMethods: if isinstance(file, (types.InputFile, types.InputFileBig)): return file # Already uploaded - if not file_name and getattr(file, 'name', None): - file_name = file.name - - if file_size is not None: - pass # do nothing as it's already kwown - elif isinstance(file, str): - file_size = os.path.getsize(file) - stream = open(file, 'rb') - close_stream = True - elif isinstance(file, bytes): - file_size = len(file) - stream = io.BytesIO(file) - close_stream = True - else: - if not callable(getattr(file, 'read', None)): - raise TypeError('file description should have a `read` method') - - if callable(getattr(file, 'seekable', None)): - seekable = await helpers._maybe_await(file.seekable()) - else: - seekable = False - - if seekable: - pos = await helpers._maybe_await(file.tell()) - await helpers._maybe_await(file.seek(0, os.SEEK_END)) - file_size = await helpers._maybe_await(file.tell()) - await helpers._maybe_await(file.seek(pos, os.SEEK_SET)) - - stream = file - close_stream = False - else: - self._log[__name__].warning( - 'Could not determine file size beforehand so the entire ' - 'file will be read in-memory') - - data = await helpers._maybe_await(file.read()) - stream = io.BytesIO(data) - close_stream = True - file_size = len(data) - - # File will now either be a string or bytes - if not part_size_kb: - part_size_kb = utils.get_appropriated_part_size(file_size) - - if part_size_kb > 512: - raise ValueError('The part size must be less or equal to 512KB') - - part_size = int(part_size_kb * 1024) - if part_size % 1024 != 0: - raise ValueError( - 'The part size must be evenly divisible by 1024') - - # Set a default file name if None was specified - file_id = helpers.generate_random_long() - if not file_name: - if isinstance(file, str): - file_name = 
os.path.basename(file) - else: - file_name = str(file_id) - - # If the file name lacks extension, add it if possible. - # Else Telegram complains with `PHOTO_EXT_INVALID_ERROR` - # even if the uploaded image is indeed a photo. - if not os.path.splitext(file_name)[-1]: - file_name += utils._get_extension(file) - - # Determine whether the file is too big (over 10MB) or not - # Telegram does make a distinction between smaller or larger files - is_big = file_size > 10 * 1024 * 1024 - hash_md5 = hashlib.md5() - - part_count = (file_size + part_size - 1) // part_size - self._log[__name__].info('Uploading file of %d bytes in %d chunks of %d', - file_size, part_count, part_size) - pos = 0 - try: + async with helpers._FileStream(file, file_size=file_size) as stream: + # Opening the stream will determine the correct file size + file_size = stream.file_size + + if not part_size_kb: + part_size_kb = utils.get_appropriated_part_size(file_size) + + if part_size_kb > 512: + raise ValueError('The part size must be less or equal to 512KB') + + part_size = int(part_size_kb * 1024) + if part_size % 1024 != 0: + raise ValueError( + 'The part size must be evenly divisible by 1024') + + # Set a default file name if None was specified + file_id = helpers.generate_random_long() + if not file_name: + file_name = stream.name or str(file_id) + + # If the file name lacks extension, add it if possible. + # Else Telegram complains with `PHOTO_EXT_INVALID_ERROR` + # even if the uploaded image is indeed a photo. 
+ if not os.path.splitext(file_name)[-1]: + file_name += utils._get_extension(stream) + + # Determine whether the file is too big (over 10MB) or not + # Telegram does make a distinction between smaller or larger files + is_big = file_size > 10 * 1024 * 1024 + hash_md5 = hashlib.md5() + + part_count = (file_size + part_size - 1) // part_size + self._log[__name__].info('Uploading file of %d bytes in %d chunks of %d', + file_size, part_count, part_size) + + pos = 0 for part_index in range(part_count): # Read the file by in chunks of size part_size part = await helpers._maybe_await(stream.read(part_size)) @@ -663,9 +623,6 @@ class UploadMethods: else: raise RuntimeError( 'Failed to upload file part {}.'.format(part_index)) - finally: - if close_stream: - await helpers._maybe_await(stream.close()) if is_big: return types.InputFileBig(file_id, part_count, file_name) diff --git a/telethon/helpers.py b/telethon/helpers.py index fd37487d..6c782b0b 100644 --- a/telethon/helpers.py +++ b/telethon/helpers.py @@ -1,9 +1,13 @@ """Various helpers not related to the Telegram API itself""" import asyncio +import io import enum import os import struct import inspect +import logging +import functools +from pathlib import Path from hashlib import sha1 @@ -13,6 +17,9 @@ class _EntityType(enum.Enum): CHANNEL = 2 +_log = logging.getLogger(__name__) + + # region Multiple utilities @@ -280,4 +287,105 @@ class TotalList(list): ', '.join(repr(x) for x in self), self.total) +class _FileStream(io.IOBase): + """ + Proxy around things that represent a file and need to be used as streams + which may or may not need to be closed. + + This will handle `pathlib.Path`, `str` paths, in-memory `bytes`, and + anything IO-like (including `aiofiles`). + + It also provides access to the name and file size (also necessary).
+ """ + def __init__(self, file, *, file_size=None): + if isinstance(file, Path): + file = str(file.absolute()) + + self._file = file + self._name = None + self._size = file_size + self._stream = None + self._close_stream = None + + async def __aenter__(self): + if isinstance(self._file, str): + self._name = os.path.basename(self._file) + self._size = os.path.getsize(self._file) + self._stream = open(self._file, 'rb') + self._close_stream = True + + elif isinstance(self._file, bytes): + self._size = len(self._file) + self._stream = io.BytesIO(self._file) + self._close_stream = True + + elif not callable(getattr(self._file, 'read', None)): + raise TypeError('file description should have a `read` method') + + elif self._size is not None: + self._name = getattr(self._file, 'name', None) + self._stream = self._file + self._close_stream = False + + else: + if callable(getattr(self._file, 'seekable', None)): + seekable = await _maybe_await(self._file.seekable()) + else: + seekable = False + + if seekable: + pos = await _maybe_await(self._file.tell()) + await _maybe_await(self._file.seek(0, os.SEEK_END)) + self._size = await _maybe_await(self._file.tell()) + await _maybe_await(self._file.seek(pos, os.SEEK_SET)) + self._stream = self._file + self._close_stream = False + else: + _log.warning( + 'Could not determine file size beforehand so the entire ' + 'file will be read in-memory') + + data = await _maybe_await(self._file.read()) + self._size = len(data) + self._stream = io.BytesIO(data) + self._close_stream = True + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self._close_stream and self._stream: + await _maybe_await(self._stream.close()) + + @property + def file_size(self): + return self._size + + @property + def name(self): + return self._name + + # Proxy all the methods. 
Doesn't need to be readable (makes multiline edits easier) + def read(self, *args, **kwargs): return self._stream.read(*args, **kwargs) + def readinto(self, *args, **kwargs): return self._stream.readinto(*args, **kwargs) + def write(self, *args, **kwargs): return self._stream.write(*args, **kwargs) + def fileno(self, *args, **kwargs): return self._stream.fileno(*args, **kwargs) + def flush(self, *args, **kwargs): return self._stream.flush(*args, **kwargs) + def isatty(self, *args, **kwargs): return self._stream.isatty(*args, **kwargs) + def readable(self, *args, **kwargs): return self._stream.readable(*args, **kwargs) + def readline(self, *args, **kwargs): return self._stream.readline(*args, **kwargs) + def readlines(self, *args, **kwargs): return self._stream.readlines(*args, **kwargs) + def seek(self, *args, **kwargs): return self._stream.seek(*args, **kwargs) + def seekable(self, *args, **kwargs): return self._stream.seekable(*args, **kwargs) + def tell(self, *args, **kwargs): return self._stream.tell(*args, **kwargs) + def truncate(self, *args, **kwargs): return self._stream.truncate(*args, **kwargs) + def writable(self, *args, **kwargs): return self._stream.writable(*args, **kwargs) + def writelines(self, *args, **kwargs): return self._stream.writelines(*args, **kwargs) + + # close is special because it will be called by __del__ but we do NOT + # want to close the file unless we have to (we're just a wrapper). + # Instead, we do nothing (we should be used through `async with`, which + # has its own mechanism to close the file correctly). + def close(self, *args, **kwargs): + pass + # endregion diff --git a/telethon/utils.py b/telethon/utils.py index 86a8661e..8fbd9f48 100644 --- a/telethon/utils.py +++ b/telethon/utils.py @@ -611,7 +611,9 @@ def _get_metadata(file): # The parser may fail and we don't want to crash if # the extraction process fails.
try: - # Note: aiofiles are intentionally left out for simplicity + # Note: aiofiles are intentionally left out for simplicity. + # `helpers._FileStream` is async only for simplicity too, so can't + # reuse it here. if isinstance(file, str): stream = open(file, 'rb') elif isinstance(file, bytes):