[V2] Upgrade ruff and mypy version, format files (#4474)

This commit is contained in:
Jahongir Qurbonov
2024-10-06 23:05:11 +05:00
committed by GitHub
parent 918f719ab2
commit 86d41e1f06
67 changed files with 177 additions and 118 deletions

View File

@@ -18,7 +18,7 @@ class InlineResults(metaclass=NoPublicConstructor):
bot: abcs.InputUser,
query: str,
peer: Optional[PeerRef],
):
) -> None:
self._client = client
self._bot = bot
self._query = query

View File

@@ -29,7 +29,7 @@ class ParticipantList(AsyncList[Participant]):
self,
client: Client,
peer: ChannelRef | GroupRef,
):
) -> None:
super().__init__()
self._client = client
self._peer = peer
@@ -106,7 +106,7 @@ class RecentActionList(AsyncList[RecentAction]):
self,
client: Client,
peer: ChannelRef | GroupRef,
):
) -> None:
super().__init__()
self._client = client
self._peer = peer
@@ -148,7 +148,7 @@ class ProfilePhotoList(AsyncList[File]):
self,
client: Client,
peer: PeerRef,
):
) -> None:
super().__init__()
self._client = client
self._peer = peer

View File

@@ -20,7 +20,7 @@ if TYPE_CHECKING:
class DialogList(AsyncList[Dialog]):
def __init__(self, client: Client):
def __init__(self, client: Client) -> None:
super().__init__()
self._client = client
self._offset = 0
@@ -93,7 +93,7 @@ async def delete_dialog(self: Client, dialog: Peer | PeerRef, /) -> None:
class DraftList(AsyncList[Draft]):
def __init__(self, client: Client):
def __init__(self, client: Client) -> None:
super().__init__()
self._client = client
self._offset = 0

View File

@@ -425,7 +425,7 @@ class FileBytesList(AsyncList[bytes]):
self,
client: Client,
file: File,
):
) -> None:
super().__init__()
self._client = client
self._loc = file._input_location()

View File

@@ -253,7 +253,7 @@ class HistoryList(MessageList):
*,
offset_id: int,
offset_date: int,
):
) -> None:
super().__init__()
self._client = client
self._peer = peer
@@ -323,7 +323,7 @@ class CherryPickedList(MessageList):
client: Client,
peer: PeerRef,
ids: list[int],
):
) -> None:
super().__init__()
self._client = client
self._peer = peer
@@ -367,7 +367,7 @@ class SearchList(MessageList):
query: str,
offset_id: int,
offset_date: int,
):
) -> None:
super().__init__()
self._client = client
self._peer = peer
@@ -434,7 +434,7 @@ class GlobalSearchList(MessageList):
query: str,
offset_id: int,
offset_date: int,
):
) -> None:
super().__init__()
self._client = client
self._limit = limit

View File

@@ -9,7 +9,7 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, TypeVar
from ....version import __version__
from ...mtproto import BadStatus, Full, RpcError
from ...mtproto import BadStatusError, Full, RpcError
from ...mtsender import Connector, Sender
from ...mtsender import connect as do_connect_sender
from ...session import DataCenter
@@ -120,7 +120,7 @@ async def connect_sender(
),
)
)
except BadStatus as e:
except BadStatusError as e:
if e.status == 404 and auth:
dc = DataCenter(
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None

View File

@@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Optional, Sequence, Type
from ...session import Gap
from ...session import GapError
from ...tl import abcs
from ..events import Continue, Event
from ..events.filters import FilterType
@@ -80,14 +80,14 @@ def process_socket_updates(client: Client, all_updates: list[abcs.Updates]) -> N
for updates in all_updates:
try:
client._message_box.ensure_known_peer_hashes(updates, client._chat_hashes)
except Gap:
except GapError:
return
try:
result, users, chats = client._message_box.process_updates(
updates, client._chat_hashes
)
except Gap:
except GapError:
return
extend_update_queue(client, result, users, chats)

View File

@@ -25,7 +25,7 @@ async def get_me(self: Client) -> Optional[User]:
class ContactList(AsyncList[User]):
def __init__(self, client: Client):
def __init__(self, client: Client) -> None:
super().__init__()
self._client = client

View File

@@ -46,7 +46,7 @@ class Raw(Event):
client: Client,
update: abcs.Update,
chat_map: dict[int, Peer],
):
) -> None:
self._client = client
self._raw = update
self._chat_map = chat_map

View File

@@ -25,7 +25,7 @@ class ButtonCallback(Event):
client: Client,
update: types.UpdateBotCallbackQuery,
chat_map: dict[int, Peer],
):
) -> None:
self._client = client
self._raw = update
self._chat_map = chat_map
@@ -101,7 +101,7 @@ class InlineQuery(Event):
Only bot accounts can receive this event.
"""
def __init__(self, update: types.UpdateBotInlineQuery):
def __init__(self, update: types.UpdateBotInlineQuery) -> None:
self._raw = update
@classmethod

View File

@@ -36,20 +36,20 @@ class HTMLToTelegramParser(HTMLParser):
self._open_tags_meta.appendleft(None)
attributes = dict(attrs)
EntityType: Optional[Type[MessageEntity]] = None
entity_type: Optional[Type[MessageEntity]] = None
args = {}
if tag == "strong" or tag == "b":
EntityType = MessageEntityBold
entity_type = MessageEntityBold
elif tag == "em" or tag == "i":
EntityType = MessageEntityItalic
entity_type = MessageEntityItalic
elif tag == "u":
EntityType = MessageEntityUnderline
entity_type = MessageEntityUnderline
elif tag == "del" or tag == "s":
EntityType = MessageEntityStrike
entity_type = MessageEntityStrike
elif tag == "blockquote":
EntityType = MessageEntityBlockquote
entity_type = MessageEntityBlockquote
elif tag == "details":
EntityType = MessageEntitySpoiler
entity_type = MessageEntitySpoiler
elif tag == "code":
try:
# If we're in the middle of a <pre> tag, this <code> tag is
@@ -63,9 +63,9 @@ class HTMLToTelegramParser(HTMLParser):
if cls := attributes.get("class"):
pre.language = cls[len("language-") :]
except KeyError:
EntityType = MessageEntityCode
entity_type = MessageEntityCode
elif tag == "pre":
EntityType = MessageEntityPre
entity_type = MessageEntityPre
args["language"] = ""
elif tag == "a":
url = attributes.get("href")
@@ -73,20 +73,20 @@ class HTMLToTelegramParser(HTMLParser):
return
if url.startswith("mailto:"):
url = url[len("mailto:") :]
EntityType = MessageEntityEmail
entity_type = MessageEntityEmail
else:
if self.get_starttag_text() == url:
EntityType = MessageEntityUrl
entity_type = MessageEntityUrl
else:
EntityType = MessageEntityTextUrl
entity_type = MessageEntityTextUrl
args["url"] = del_surrogate(url)
url = None
self._open_tags_meta.popleft()
self._open_tags_meta.appendleft(url)
if EntityType and tag not in self._building_entities:
Et = cast(Any, EntityType)
self._building_entities[tag] = Et(
if entity_type and tag not in self._building_entities:
any_entity_type = cast(Any, entity_type)
self._building_entities[tag] = any_entity_type(
offset=len(self.text),
# The length will be determined when closing the tag.
length=0,

View File

@@ -22,7 +22,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
This class is constructed by calling :meth:`telethon.Client.prepare_album`.
"""
def __init__(self, *, client: Client):
def __init__(self, *, client: Client) -> None:
self._client = client
self._medias: list[types.InputSingleMedia] = []

View File

@@ -122,7 +122,7 @@ class OutWrapper:
_fd: OutFileLike | BufferedWriter
_owned_fd: Optional[BufferedWriter]
def __init__(self, file: str | Path | OutFileLike):
def __init__(self, file: str | Path | OutFileLike) -> None:
if isinstance(file, str):
file = Path(file)
@@ -166,7 +166,7 @@ class File(metaclass=NoPublicConstructor):
thumbs: Optional[Sequence[abcs.PhotoSize]],
raw: Optional[abcs.MessageMedia | abcs.Photo | abcs.Document],
client: Optional[Client],
):
) -> None:
self._attributes = attributes
self._size = size
self._name = name

View File

@@ -25,7 +25,7 @@ class InlineResult(metaclass=NoPublicConstructor):
results: types.messages.BotResults,
result: types.BotInlineMediaResult | types.BotInlineResult,
default_peer: Optional[PeerRef],
):
) -> None:
self._client = client
self._raw_results = results
self._raw = result

View File

@@ -1,5 +1,5 @@
try:
import cryptg
import cryptg # type: ignore [import-untyped]
def ige_encrypt(
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
@@ -18,7 +18,7 @@ try:
)
except ImportError:
import pyaes
import pyaes # type: ignore [import-untyped]
def ige_encrypt(
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes

View File

@@ -3,7 +3,7 @@ from .authentication import step1 as auth_step1
from .authentication import step2 as auth_step2
from .authentication import step3 as auth_step3
from .mtp import (
BadMessage,
BadMessageError,
Deserialization,
Encrypted,
MsgId,
@@ -13,7 +13,14 @@ from .mtp import (
RpcResult,
Update,
)
from .transport import Abridged, BadStatus, Full, Intermediate, MissingBytes, Transport
from .transport import (
Abridged,
BadStatusError,
Full,
Intermediate,
MissingBytesError,
Transport,
)
from .utils import DEFAULT_COMPRESSION_THRESHOLD
__all__ = [
@@ -25,7 +32,7 @@ __all__ = [
"auth_step1",
"auth_step2",
"auth_step3",
"BadMessage",
"BadMessageError",
"Deserialization",
"Encrypted",
"MsgId",
@@ -35,10 +42,10 @@ __all__ = [
"RpcResult",
"Update",
"Abridged",
"BadStatus",
"BadStatusError",
"Full",
"Intermediate",
"MissingBytes",
"MissingBytesError",
"Transport",
"DEFAULT_COMPRESSION_THRESHOLD",
]

View File

@@ -310,4 +310,4 @@ def check_new_nonce_hash(got: int, expected: int) -> None:
def check_g_in_range(value: int, low: int, high: int) -> None:
if not (low < value < high):
raise ValueError(f"g parameter {value} not in range({low+1}, {high})")
raise ValueError(f"g parameter {value} not in range({low + 1}, {high})")

View File

@@ -1,11 +1,19 @@
from .encrypted import Encrypted
from .plain import Plain
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult, Update
from .types import (
BadMessageError,
Deserialization,
MsgId,
Mtp,
RpcError,
RpcResult,
Update,
)
__all__ = [
"Encrypted",
"Plain",
"BadMessage",
"BadMessageError",
"Deserialization",
"MsgId",
"Mtp",

View File

@@ -60,7 +60,15 @@ from ..utils import (
gzip_decompress,
message_requires_ack,
)
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult, Update
from .types import (
BadMessageError,
Deserialization,
MsgId,
Mtp,
RpcError,
RpcResult,
Update,
)
NUM_FUTURE_SALTS = 64
@@ -269,7 +277,7 @@ class Encrypted(Mtp):
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification))
exc = BadMessage(msg_id=MsgId(bad_msg.bad_msg_id), code=bad_msg.error_code)
exc = BadMessageError(msg_id=MsgId(bad_msg.bad_msg_id), code=bad_msg.error_code)
if bad_msg.bad_msg_id == self._salt_request_msg_id:
# Response to internal request, do not propagate.

View File

@@ -15,7 +15,7 @@ class Update:
__slots__ = ("body",)
def __init__(self, body: bytes | bytearray | memoryview):
def __init__(self, body: bytes | bytearray | memoryview) -> None:
self.body = body
@@ -26,7 +26,7 @@ class RpcResult:
__slots__ = ("msg_id", "body")
def __init__(self, msg_id: MsgId, body: bytes | bytearray | memoryview):
def __init__(self, msg_id: MsgId, body: bytes | bytearray | memoryview) -> None:
self.msg_id = msg_id
self.body = body
@@ -142,7 +142,7 @@ RETRYABLE_MSG_IDS = {16, 17, 48}
NON_FATAL_MSG_IDS = RETRYABLE_MSG_IDS & {32, 33}
class BadMessage(ValueError):
class BadMessageError(ValueError):
def __init__(
self,
*args: object,
@@ -178,7 +178,7 @@ class BadMessage(ValueError):
return self._code == other._code
Deserialization = Update | RpcResult | RpcError | BadMessage
Deserialization = Update | RpcResult | RpcError | BadMessageError
# https://core.telegram.org/mtproto/description

View File

@@ -1,6 +1,13 @@
from .abcs import BadStatus, MissingBytes, Transport
from .abcs import BadStatusError, MissingBytesError, Transport
from .abridged import Abridged
from .full import Full
from .intermediate import Intermediate
__all__ = ["BadStatus", "MissingBytes", "Transport", "Abridged", "Full", "Intermediate"]
__all__ = [
"BadStatusError",
"MissingBytesError",
"Transport",
"Abridged",
"Full",
"Intermediate",
]

View File

@@ -16,12 +16,12 @@ class Transport(ABC):
pass
class MissingBytes(ValueError):
class MissingBytesError(ValueError):
def __init__(self, *, expected: int, got: int) -> None:
super().__init__(f"missing bytes, expected: {expected}, got: {got}")
class BadStatus(ValueError):
class BadStatusError(ValueError):
def __init__(self, *, status: int) -> None:
super().__init__(f"transport reported bad status: {status}")
self.status = status

View File

@@ -1,6 +1,6 @@
import struct
from .abcs import BadStatus, MissingBytes, OutFn, Transport
from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
class Abridged(Transport):
@@ -38,25 +38,25 @@ class Abridged(Transport):
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if not input:
raise MissingBytes(expected=1, got=0)
raise MissingBytesError(expected=1, got=0)
length = input[0]
if 1 < length < 127:
header_len = 1
elif len(input) < 4:
raise MissingBytes(expected=4, got=len(input))
raise MissingBytesError(expected=4, got=len(input))
else:
header_len = 4
length = struct.unpack_from("<i", input)[0] >> 8
if length <= 0:
if length < 0:
raise BadStatus(status=-length)
raise BadStatusError(status=-length)
raise ValueError(f"bad length, expected > 0, got: {length}")
length *= 4
if len(input) < header_len + length:
raise MissingBytes(expected=header_len + length, got=len(input))
raise MissingBytesError(expected=header_len + length, got=len(input))
output += memoryview(input)[header_len : header_len + length]
return header_len + length

View File

@@ -1,7 +1,7 @@
import struct
from zlib import crc32
from .abcs import BadStatus, MissingBytes, OutFn, Transport
from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
class Full(Transport):
@@ -37,17 +37,17 @@ class Full(Transport):
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if len(input) < 4:
raise MissingBytes(expected=4, got=len(input))
raise MissingBytesError(expected=4, got=len(input))
length = struct.unpack_from("<i", input)[0]
assert isinstance(length, int)
if length < 12:
if length < 0:
raise BadStatus(status=-length)
raise BadStatusError(status=-length)
raise ValueError(f"bad length, expected > 12, got: {length}")
if len(input) < length:
raise MissingBytes(expected=length, got=len(input))
raise MissingBytesError(expected=length, got=len(input))
seq = struct.unpack_from("<i", input, 4)[0]
if seq != self._recv_seq:

View File

@@ -1,6 +1,6 @@
import struct
from .abcs import BadStatus, MissingBytes, OutFn, Transport
from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
class Intermediate(Transport):
@@ -34,19 +34,19 @@ class Intermediate(Transport):
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if len(input) < 4:
raise MissingBytes(expected=4, got=len(input))
raise MissingBytesError(expected=4, got=len(input))
length = struct.unpack_from("<i", input)[0]
assert isinstance(length, int)
if len(input) < length:
raise MissingBytes(expected=length, got=len(input))
raise MissingBytesError(expected=length, got=len(input))
if length <= 4:
if (
length >= 4
and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0
):
raise BadStatus(status=-status)
raise BadStatusError(status=-status)
raise ValueError(f"bad length, expected > 0, got: {length}")

View File

@@ -10,9 +10,9 @@ from typing import Generic, Optional, Protocol, Self, Type, TypeVar
from ..crypto import AuthKey
from ..mtproto import (
BadMessage,
BadMessageError,
Encrypted,
MissingBytes,
MissingBytesError,
MsgId,
Mtp,
Plain,
@@ -133,7 +133,7 @@ class NotSerialized(RequestState):
class Serialized(RequestState):
__slots__ = ("msg_id", "container_msg_id")
def __init__(self, msg_id: MsgId):
def __init__(self, msg_id: MsgId) -> None:
self.msg_id = msg_id
self.container_msg_id = msg_id
@@ -141,7 +141,7 @@ class Serialized(RequestState):
class Sent(RequestState):
__slots__ = ("msg_id", "container_msg_id")
def __init__(self, msg_id: MsgId, container_msg_id: MsgId):
def __init__(self, msg_id: MsgId, container_msg_id: MsgId) -> None:
self.msg_id = msg_id
self.container_msg_id = container_msg_id
@@ -298,7 +298,7 @@ class Sender:
self._mtp_buffer.clear()
try:
n = self._transport.unpack(self._read_buffer, self._mtp_buffer)
except MissingBytes:
except MissingBytesError:
break
else:
del self._read_buffer[:n]
@@ -403,7 +403,7 @@ class Sender:
result,
)
def _process_bad_message(self, result: BadMessage) -> None:
def _process_bad_message(self, result: BadMessageError) -> None:
for req in self._drain_requests(result.msg_id):
if result.retryable:
self._logger.log(

View File

@@ -11,7 +11,7 @@ from .message_box import (
BOT_CHANNEL_DIFF_LIMIT,
NO_UPDATES_TIMEOUT,
USER_CHANNEL_DIFF_LIMIT,
Gap,
GapError,
MessageBox,
PossibleGap,
PrematureEndReason,
@@ -32,7 +32,7 @@ __all__ = [
"BOT_CHANNEL_DIFF_LIMIT",
"NO_UPDATES_TIMEOUT",
"USER_CHANNEL_DIFF_LIMIT",
"Gap",
"GapError",
"MessageBox",
"PossibleGap",
"PrematureEndReason",

View File

@@ -9,7 +9,7 @@ PeerRefType: TypeAlias = Type[UserRef] | Type[ChannelRef] | Type[GroupRef]
class ChatHashCache:
__slots__ = ("_hash_map", "_self_id", "_self_bot")
def __init__(self, self_user: Optional[tuple[int, bool]]):
def __init__(self, self_user: Optional[tuple[int, bool]]) -> None:
self._hash_map: dict[int, tuple[PeerRefType, int]] = {}
self._self_id = self_user[0] if self_user else None
self._self_bot = self_user[1] if self_user else False

View File

@@ -2,7 +2,7 @@ from .defs import (
BOT_CHANNEL_DIFF_LIMIT,
NO_UPDATES_TIMEOUT,
USER_CHANNEL_DIFF_LIMIT,
Gap,
GapError,
PossibleGap,
PrematureEndReason,
PtsInfo,
@@ -14,7 +14,7 @@ __all__ = [
"BOT_CHANNEL_DIFF_LIMIT",
"NO_UPDATES_TIMEOUT",
"USER_CHANNEL_DIFF_LIMIT",
"Gap",
"GapError",
"PossibleGap",
"PrematureEndReason",
"PtsInfo",

View File

@@ -2,7 +2,7 @@ from typing import Optional
from ...tl import abcs, types
from ..chat import ChatHashCache
from .defs import ENTRY_ACCOUNT, ENTRY_SECRET, NO_SEQ, Gap, PtsInfo
from .defs import ENTRY_ACCOUNT, ENTRY_SECRET, NO_SEQ, GapError, PtsInfo
def updates_(updates: types.Updates) -> types.UpdatesCombined:
@@ -147,7 +147,7 @@ def update_short_sent_message(
def adapt(updates: abcs.Updates, chat_hashes: ChatHashCache) -> types.UpdatesCombined:
if isinstance(updates, types.UpdatesTooLong):
raise Gap
raise GapError
elif isinstance(updates, types.UpdateShortMessage):
return update_short_message(updates, chat_hashes.self_id)
elif isinstance(updates, types.UpdateShortChatMessage):

View File

@@ -80,6 +80,6 @@ class PrematureEndReason(Enum):
BANNED = "ban"
class Gap(ValueError):
class GapError(ValueError):
def __repr__(self) -> str:
return "Gap()"

View File

@@ -20,7 +20,7 @@ from .defs import (
POSSIBLE_GAP_TIMEOUT,
USER_CHANNEL_DIFF_LIMIT,
Entry,
Gap,
GapError,
PossibleGap,
PrematureEndReason,
State,
@@ -252,7 +252,7 @@ class MessageBox:
)
if can_recover:
self.try_begin_get_diff(ENTRY_ACCOUNT, "missing hash")
raise Gap
raise GapError
# https://core.telegram.org/api/updates
def process_updates(
@@ -281,7 +281,7 @@ class MessageBox:
return result, combined.users, combined.chats
elif self.seq + 1 < combined.seq_start:
self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap")
raise Gap
raise GapError
def update_sort_key(update: abcs.Update) -> int:
pts = pts_info_from_update(update)

View File

@@ -168,7 +168,7 @@ class Session:
dcs: Optional[list[DataCenter]] = None,
user: Optional[User] = None,
state: Optional[UpdateState] = None,
):
) -> None:
self.dcs = dcs or []
"List of known data-centers."
self.user = user

View File

@@ -13,7 +13,7 @@ class MemorySession(Storage):
__slots__ = ("session",)
def __init__(self, session: Optional[Session] = None):
def __init__(self, session: Optional[Session] = None) -> None:
self.session = session
async def load(self) -> Optional[Session]:

View File

@@ -20,7 +20,7 @@ class SqliteSession(Storage):
an VCS by accident (adding ``*.session`` to ``.gitignore`` will catch them).
"""
def __init__(self, file: str | Path):
def __init__(self, file: str | Path) -> None:
path = Path(file)
if not path.suffix:
path = path.with_suffix(EXTENSION)

View File

@@ -26,11 +26,11 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
raise RuntimeError(
"generated api and mtproto schemas cannot have colliding constructor identifiers"
)
ALL_TYPES = API_TYPES | MTPROTO_TYPES
all_types = API_TYPES | MTPROTO_TYPES
# Signatures don't fully match, but this is a private method
# and all previous uses are compatible with `dict.get`.
Reader._get_ty = ALL_TYPES.get # type: ignore [assignment]
Reader._get_ty = all_types.get # type: ignore [assignment]
return Reader._get_ty(constructor_id)

View File

@@ -17,9 +17,9 @@ def _bootstrap_get_deserializer(
raise RuntimeError(
"generated api and mtproto schemas cannot have colliding constructor identifiers"
)
ALL_DESER = API_DESER | MTPROTO_DESER
all_deser = API_DESER | MTPROTO_DESER
Request._get_deserializer = ALL_DESER.get # type: ignore [assignment]
Request._get_deserializer = all_deser.get # type: ignore [assignment]
return Request._get_deserializer(constructor_id)