This commit is contained in:
Jahongir Qurbonov 2025-06-03 05:27:52 +00:00 committed by GitHub
commit 0eaf2f5fac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 426 additions and 47 deletions

View File

@ -10,6 +10,7 @@ from typing_extensions import Self
from ....version import __version__ as default_version from ....version import __version__ as default_version
from ...mtsender import Connector, Sender from ...mtsender import Connector, Sender
from ...mtsender.reconnection import NoReconnect, ReconnectionPolicy
from ...session import ( from ...session import (
ChannelRef, ChannelRef,
ChatHashCache, ChatHashCache,
@ -215,6 +216,7 @@ class Client:
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
datacenter: Optional[DataCenter] = None, datacenter: Optional[DataCenter] = None,
connector: Optional[Connector] = None, connector: Optional[Connector] = None,
reconnection_policy: Optional[ReconnectionPolicy] = None,
) -> None: ) -> None:
assert __package__ assert __package__
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")]) base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
@ -246,6 +248,7 @@ class Client:
update_queue_limit=update_queue_limit, update_queue_limit=update_queue_limit,
base_logger=base_logger, base_logger=base_logger,
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
reconnection_policy=reconnection_policy or NoReconnect(),
) )
self._session = Session() self._session = Session()
@ -253,9 +256,9 @@ class Client:
self._message_box = MessageBox(base_logger=base_logger) self._message_box = MessageBox(base_logger=base_logger)
self._chat_hashes = ChatHashCache(None) self._chat_hashes = ChatHashCache(None)
self._last_update_limit_warn: Optional[float] = None self._last_update_limit_warn: Optional[float] = None
self._updates: asyncio.Queue[ self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = (
tuple[abcs.Update, dict[int, Peer]] asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0) )
self._dispatcher: Optional[asyncio.Task[None]] = None self._dispatcher: Optional[asyncio.Task[None]] = None
self._handlers: dict[ self._handlers: dict[
Type[Event], Type[Event],

View File

@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar
from ....version import __version__ from ....version import __version__
from ...mtproto import BadStatusError, Full, RpcError from ...mtproto import BadStatusError, Full, RpcError
from ...mtsender import Connector, Sender from ...mtsender import Connector, ReconnectionPolicy, Sender
from ...mtsender import connect as do_connect_sender from ...mtsender import connect as do_connect_sender
from ...session import DataCenter from ...session import DataCenter
from ...session import User as SessionUser from ...session import User as SessionUser
@ -46,6 +46,7 @@ class Config:
api_hash: str api_hash: str
base_logger: logging.Logger base_logger: logging.Logger
connector: Connector connector: Connector
reconnection_policy: ReconnectionPolicy
device_model: str = field(default_factory=default_device_model) device_model: str = field(default_factory=default_device_model)
system_version: str = field(default_factory=default_system_version) system_version: str = field(default_factory=default_system_version)
app_version: str = __version__ app_version: str = __version__
@ -100,6 +101,7 @@ async def connect_sender(
auth_key=auth, auth_key=auth,
base_logger=config.base_logger, base_logger=config.base_logger,
connector=config.connector, connector=config.connector,
reconnection_policy=config.reconnection_policy,
) )
try: try:

View File

@ -7,6 +7,24 @@ from .aes import ige_decrypt, ige_encrypt
from .auth_key import AuthKey from .auth_key import AuthKey
class InvalidBufferError(ValueError):
def __init__(self) -> None:
super().__init__("invalid ciphertext buffer length")
class AuthKeyMismatchError(ValueError):
def __init__(self) -> None:
super().__init__("server authkey mismatches with ours")
class MsgKeyMismatchError(ValueError):
def __init__(self) -> None:
super().__init__("server msgkey mismatches with ours")
CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError
# "where x = 0 for messages from client to server and x = 8 for those from server to client" # "where x = 0 for messages from client to server and x = 8 for those from server to client"
class Side(IntEnum): class Side(IntEnum):
CLIENT = 0 CLIENT = 0
@ -77,14 +95,14 @@ def decrypt_data_v2(
x = int(side) x = int(side)
if len(ciphertext) < 24 or (len(ciphertext) - 24) % 16 != 0: if len(ciphertext) < 24 or (len(ciphertext) - 24) % 16 != 0:
raise ValueError("invalid ciphertext buffer length") raise InvalidBufferError()
# salt, session_id and sequence_number should also be checked. # salt, session_id and sequence_number should also be checked.
# However, not doing so has worked fine for years. # However, not doing so has worked fine for years.
key_id = ciphertext[:8] key_id = ciphertext[:8]
if auth_key.key_id != key_id: if auth_key.key_id != key_id:
raise ValueError("server authkey mismatches with ours") raise AuthKeyMismatchError()
msg_key = ciphertext[8:24] msg_key = ciphertext[8:24]
key, iv = calc_key(auth_key, msg_key, side) key, iv = calc_key(auth_key, msg_key, side)
@ -93,7 +111,7 @@ def decrypt_data_v2(
# https://core.telegram.org/mtproto/security_guidelines#mtproto-encrypted-messages # https://core.telegram.org/mtproto/security_guidelines#mtproto-encrypted-messages
our_key = sha256(auth_key.data[x + 88 : x + 120] + plaintext).digest() our_key = sha256(auth_key.data[x + 88 : x + 120] + plaintext).digest()
if msg_key != our_key[8:24]: if msg_key != our_key[8:24]:
raise ValueError("server msgkey mismatches with ours") raise MsgKeyMismatchError()
return plaintext return plaintext

View File

@ -62,11 +62,15 @@ from ..utils import (
) )
from .types import ( from .types import (
BadMessageError, BadMessageError,
DecompressionFailed,
Deserialization, Deserialization,
DeserializationFailure,
MsgBufferTooSmall,
MsgId, MsgId,
Mtp, Mtp,
RpcError, RpcError,
RpcResult, RpcResult,
UnexpectedConstructor,
Update, Update,
) )
@ -85,6 +89,7 @@ UPDATE_IDS = {
AffectedFoundMessages.constructor_id(), AffectedFoundMessages.constructor_id(),
AffectedHistory.constructor_id(), AffectedHistory.constructor_id(),
AffectedMessages.constructor_id(), AffectedMessages.constructor_id(),
# TODO InvitedUsers
} }
HEADER_LEN = 8 + 8 # salt, client_id HEADER_LEN = 8 + 8 # salt, client_id
@ -151,7 +156,7 @@ class Encrypted(Mtp):
self._last_msg_id: int self._last_msg_id: int
self._in_pending_ack: list[int] = [] self._in_pending_ack: list[int] = []
self._msg_count: int self._msg_count: int
self._reset_session() self.reset()
@property @property
def auth_key(self) -> bytes: def auth_key(self) -> bytes:
@ -166,13 +171,6 @@ class Encrypted(Mtp):
def _adjusted_now(self) -> float: def _adjusted_now(self) -> float:
return time.time() + self._time_offset return time.time() + self._time_offset
def _reset_session(self) -> None:
self._client_id = struct.unpack("<q", os.urandom(8))[0]
self._sequence = 0
self._last_msg_id = 0
self._in_pending_ack.clear()
self._msg_count = 0
def _get_new_msg_id(self) -> int: def _get_new_msg_id(self) -> int:
new_msg_id = int(self._adjusted_now() * 0x100000000) new_msg_id = int(self._adjusted_now() * 0x100000000)
if self._last_msg_id >= new_msg_id: if self._last_msg_id >= new_msg_id:
@ -245,12 +243,38 @@ class Encrypted(Mtp):
result = rpc_result.result result = rpc_result.result
msg_id = MsgId(req_msg_id) msg_id = MsgId(req_msg_id)
inner_constructor = struct.unpack_from("<I", result)[0]
try:
inner_constructor = struct.unpack_from("<I", result)[0]
except struct.error:
# If the result is empty, we can't unpack it.
# This can happen if the server returns an empty response.
logging.exception("failed to unpack inner_constructor")
self._deserialization.append(
DeserializationFailure(
msg_id=msg_id,
error=MsgBufferTooSmall(),
)
)
return
if inner_constructor == GeneratedRpcError.constructor_id(): if inner_constructor == GeneratedRpcError.constructor_id():
error = RpcError._from_mtproto_error(GeneratedRpcError.from_bytes(result)) try:
error.msg_id = msg_id error = RpcError._from_mtproto_error(
self._deserialization.append(error) GeneratedRpcError.from_bytes(result)
)
error.msg_id = msg_id
self._deserialization.append(error)
except Exception:
logging.exception("failed to deserialize error")
self._deserialization.append(
DeserializationFailure(
msg_id=msg_id,
error=UnexpectedConstructor(
id=inner_constructor,
),
)
)
elif inner_constructor == RpcAnswerUnknown.constructor_id(): elif inner_constructor == RpcAnswerUnknown.constructor_id():
pass # msg_id = rpc_drop_answer.msg_id pass # msg_id = rpc_drop_answer.msg_id
elif inner_constructor == RpcAnswerDroppedRunning.constructor_id(): elif inner_constructor == RpcAnswerDroppedRunning.constructor_id():
@ -258,9 +282,15 @@ class Encrypted(Mtp):
elif inner_constructor == RpcAnswerDropped.constructor_id(): elif inner_constructor == RpcAnswerDropped.constructor_id():
pass # dropped pass # dropped
elif inner_constructor == GzipPacked.constructor_id(): elif inner_constructor == GzipPacked.constructor_id():
body = gzip_decompress(GzipPacked.from_bytes(result)) try:
self._store_own_updates(body) body = gzip_decompress(GzipPacked.from_bytes(result))
self._deserialization.append(RpcResult(msg_id, body)) self._store_own_updates(body)
self._deserialization.append(RpcResult(msg_id, body))
except Exception:
logging.exception("failed to decompress response")
self._deserialization.append(
DeserializationFailure(msg_id=msg_id, error=DecompressionFailed())
)
else: else:
self._store_own_updates(result) self._store_own_updates(result)
self._deserialization.append(RpcResult(msg_id, result)) self._deserialization.append(RpcResult(msg_id, result))
@ -300,7 +330,7 @@ class Encrypted(Mtp):
elif bad_msg.error_code in (16, 17): elif bad_msg.error_code in (16, 17):
self._correct_time_offset(message.msg_id) self._correct_time_offset(message.msg_id)
elif bad_msg.error_code in (32, 33): elif bad_msg.error_code in (32, 33):
self._reset_session() self.reset()
else: else:
raise exc raise exc
@ -365,6 +395,9 @@ class Encrypted(Mtp):
for inner_message in container.messages: for inner_message in container.messages:
self._process_message(inner_message) self._process_message(inner_message)
def _handle_msg_copy(self, message: Message) -> None:
raise RuntimeError("msg_copy should not be used")
def _handle_gzip_packed(self, message: Message) -> None: def _handle_gzip_packed(self, message: Message) -> None:
container = GzipPacked.from_bytes(message.body) container = GzipPacked.from_bytes(message.body)
inner_body = gzip_decompress(container) inner_body = gzip_decompress(container)
@ -459,3 +492,11 @@ class Encrypted(Mtp):
result = self._deserialization[:] result = self._deserialization[:]
self._deserialization.clear() self._deserialization.clear()
return result return result
def reset(self) -> None:
self._client_id = struct.unpack("<q", os.urandom(8))[0]
self._sequence = 0
self._last_msg_id = 0
self._in_pending_ack.clear()
self._msg_count = 0
self._salt_request_msg_id = None

View File

@ -2,7 +2,16 @@ import struct
from typing import Optional from typing import Optional
from ..utils import check_message_buffer from ..utils import check_message_buffer
from .types import Deserialization, MsgId, Mtp, RpcResult from .types import (
BadAuthKeyError,
BadMsgIdError,
Deserialization,
MsgId,
Mtp,
NegativeLengthError,
RpcResult,
TooLongMsgError,
)
class Plain(Mtp): class Plain(Mtp):
@ -38,18 +47,19 @@ class Plain(Mtp):
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload) auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
if auth_key_id != 0: if auth_key_id != 0:
raise ValueError(f"bad auth key, expected: 0, got: {auth_key_id}") raise BadAuthKeyError(got=auth_key_id, expected=0)
# https://core.telegram.org/mtproto/description#message-identifier-msg-id # https://core.telegram.org/mtproto/description#message-identifier-msg-id
if msg_id <= 0 or (msg_id % 4) != 1: if msg_id <= 0 or (msg_id % 4) != 1:
raise ValueError(f"bad msg id, got: {msg_id}") raise BadMsgIdError(got=msg_id)
if length < 0: if length < 0:
raise ValueError(f"bad length: expected >= 0, got: {length}") raise NegativeLengthError(got=length)
if 20 + length > len(payload): if 20 + length > (lp := len(payload)):
raise ValueError( raise TooLongMsgError(got=length, max_length=lp - 20)
f"message too short, expected: {20 + length}, got {len(payload)}"
)
return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))] return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))]
def reset(self) -> None:
self._buffer.clear()

View File

@ -5,6 +5,7 @@ from typing import NewType, Optional
from typing_extensions import Self from typing_extensions import Self
from ...crypto.crypto import CryptoError
from ...tl.mtproto.types import RpcError as GeneratedRpcError from ...tl.mtproto.types import RpcError as GeneratedRpcError
MsgId = NewType("MsgId", int) MsgId = NewType("MsgId", int)
@ -180,7 +181,112 @@ class BadMessageError(ValueError):
return self._code == other._code return self._code == other._code
Deserialization = Update | RpcResult | RpcError | BadMessageError # Deserialization errors are not fatal, so we don't subclass RpcError.
class BadAuthKeyError(ValueError):
def __init__(self, *args: object, got: int, expected: int) -> None:
super().__init__(f"bad server auth key (got {got}, expected {expected})", *args)
self._got = got
self._expected = expected
@property
def got(self):
return self._got
@property
def expected(self):
return self._expected
class BadMsgIdError(ValueError):
def __init__(self, *args: object, got: int) -> None:
super().__init__(f"bad server message id (got {got})", *args)
self._got = got
@property
def got(self):
return self._got
class NegativeLengthError(ValueError):
def __init__(self, *args: object, got: int) -> None:
super().__init__(f"bad server message length (got {got})", *args)
self._got = got
@property
def got(self):
return self._got
class TooLongMsgError(ValueError):
def __init__(self, *args: object, got: int, max_length: int) -> None:
super().__init__(
f"bad server message length (got {got}, when at most it should be {max_length})",
*args,
)
self._got = got
self._expected = max_length
@property
def got(self):
return self._got
@property
def expected(self):
return self._expected
class MsgBufferTooSmall(ValueError):
def __init__(self, *args: object) -> None:
super().__init__(
"server responded with a payload that's too small to fit a valid message",
*args,
)
class DecompressionFailed(ValueError):
def __init__(self, *args: object) -> None:
super().__init__("failed to decompress server's data", *args)
class UnexpectedConstructor(ValueError):
def __init__(self, *args: object, id: int) -> None:
super().__init__(f"unexpected constructor: {id:08x}", *args)
class DecryptionError(ValueError):
def __init__(self, *args: object, error: CryptoError) -> None:
super().__init__(f"failed to decrypt message: {error}", *args)
self._error = error
@property
def error(self):
return self._error
DeserializationError = (
BadAuthKeyError
| BadMsgIdError
| NegativeLengthError
| TooLongMsgError
| MsgBufferTooSmall
| DecompressionFailed
| UnexpectedConstructor
| DecryptionError
)
class DeserializationFailure:
__slots__ = ("msg_id", "error")
def __init__(self, msg_id: MsgId, error: DeserializationError) -> None:
self.msg_id = msg_id
self.error = error
Deserialization = (
Update | RpcResult | RpcError | BadMessageError | DeserializationFailure
)
# https://core.telegram.org/mtproto/description # https://core.telegram.org/mtproto/description
@ -209,3 +315,9 @@ class Mtp(ABC):
""" """
Deserialize incoming buffer payload. Deserialize incoming buffer payload.
""" """
@abstractmethod
def reset(self) -> None:
"""
Reset the internal buffer.
"""

View File

@ -15,13 +15,37 @@ class Transport(ABC):
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int: def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
pass pass
@abstractmethod
def reset(self):
pass
class MissingBytesError(ValueError): class MissingBytesError(ValueError):
def __init__(self, *, expected: int, got: int) -> None: def __init__(self, *, expected: int, got: int) -> None:
super().__init__(f"missing bytes, expected: {expected}, got: {got}") super().__init__(f"missing bytes, expected: {expected}, got: {got}")
class BadLenError(ValueError):
def __init__(self, *, got: int) -> None:
super().__init__(f"bad len (got {got})")
class BadSeqError(ValueError):
def __init__(self, *, expected: int, got: int) -> None:
super().__init__(f"bad seq (expected {expected}, got {got})")
class BadCrcError(ValueError):
def __init__(self, *, expected: int, got: int) -> None:
super().__init__(f"bad crc (expected {expected}, got {got})")
class BadStatusError(ValueError): class BadStatusError(ValueError):
def __init__(self, *, status: int) -> None: def __init__(self, *, status: int) -> None:
super().__init__(f"transport reported bad status: {status}") super().__init__(f"bad status (negative length -{status})")
self.status = status self.status = status
TransportError = (
MissingBytesError | BadLenError | BadSeqError | BadCrcError | BadStatusError
)

View File

@ -1,3 +1,4 @@
import logging
import struct import struct
from .abcs import BadStatusError, MissingBytesError, OutFn, Transport from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
@ -60,3 +61,7 @@ class Abridged(Transport):
output += memoryview(input)[header_len : header_len + length] output += memoryview(input)[header_len : header_len + length]
return header_len + length return header_len + length
def reset(self):
logging.info("resetting sending of header in abridged transport")
self._init = False

View File

@ -1,3 +1,4 @@
import logging
import struct import struct
from zlib import crc32 from zlib import crc32
@ -61,3 +62,8 @@ class Full(Transport):
self._recv_seq += 1 self._recv_seq += 1
output += memoryview(input)[8 : length - 4] output += memoryview(input)[8 : length - 4]
return length return length
def reset(self):
logging.info("resetting recv and send seqs in full transport")
self._send_seq = 0
self._recv_seq = 0

View File

@ -1,3 +1,4 @@
import logging
import struct import struct
from .abcs import BadStatusError, MissingBytesError, OutFn, Transport from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
@ -52,3 +53,7 @@ class Intermediate(Transport):
output += memoryview(input)[4 : 4 + length] output += memoryview(input)[4 : 4 + length]
return length + 4 return length + 4
def reset(self):
logging.info("resetting sending of header in intermediate transport")
self._init = False

View File

@ -1,3 +1,4 @@
from .reconnection import ReconnectionPolicy
from .sender import ( from .sender import (
MAXIMUM_DATA, MAXIMUM_DATA,
NO_PING_DISCONNECT, NO_PING_DISCONNECT,
@ -18,4 +19,5 @@ __all__ = [
"Connector", "Connector",
"Sender", "Sender",
"connect", "connect",
"ReconnectionPolicy",
] ]

View File

@ -0,0 +1,6 @@
from struct import error as struct_error
from ..mtproto.mtp.types import DeserializationError
from ..mtproto.transport.abcs import TransportError
ReadError = struct_error | TransportError | DeserializationError

View File

@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
class ReconnectionPolicy(ABC):
"""
Base class for reconnection policies.
This class defines the interface for reconnection policies used by the MTSender.
It allows for custom reconnection strategies to be implemented by subclasses.
"""
@abstractmethod
def should_retry(self, attempts: int) -> bool | float:
"""
Determines whether the client should retry the connection attempt.
"""
pass
class NoReconnect(ReconnectionPolicy):
def should_retry(self, attempts: int) -> bool:
return False
class FixedReconnect(ReconnectionPolicy):
__slots__ = ("max_attempts", "delay")
def __init__(self, attempts: int, delay: float):
self.max_attempts = attempts
self.delay = delay
def should_retry(self, attempts: int) -> bool | float:
if attempts < self.max_attempts:
return self.delay
return False

View File

@ -24,12 +24,14 @@ from ..mtproto import (
Update, Update,
authentication, authentication,
) )
from ..mtproto.mtp.types import DeserializationFailure
from ..tl import Request as RemoteCall from ..tl import Request as RemoteCall
from ..tl.abcs import Updates from ..tl.abcs import Updates
from ..tl.core import Serializable from ..tl.core import Serializable
from ..tl.mtproto.functions import ping_delay_disconnect from ..tl.mtproto.functions import ping_delay_disconnect
from ..tl.types import UpdateDeleteMessages, UpdateShort from ..tl.types import UpdateDeleteMessages, UpdateShort, UpdatesTooLong
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
from .reconnection import ReconnectionPolicy
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024) MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
@ -162,6 +164,8 @@ class Request(Generic[Return]):
class Sender: class Sender:
dc_id: int dc_id: int
addr: str addr: str
_connector: Connector
_reconnection_policy: ReconnectionPolicy
_logger: logging.Logger _logger: logging.Logger
_reader: AsyncReader _reader: AsyncReader
_writer: AsyncWriter _writer: AsyncWriter
@ -175,6 +179,7 @@ class Sender:
_requests: list[Request[object]] _requests: list[Request[object]]
_next_ping: float _next_ping: float
_read_buffer: bytearray _read_buffer: bytearray
_write_drain_pending: bool
@classmethod @classmethod
async def connect( async def connect(
@ -185,6 +190,7 @@ class Sender:
addr: str, addr: str,
*, *,
connector: Connector, connector: Connector,
reconnection_policy: ReconnectionPolicy,
base_logger: logging.Logger, base_logger: logging.Logger,
) -> Self: ) -> Self:
ip, port = addr.split(":") ip, port = addr.split(":")
@ -193,6 +199,8 @@ class Sender:
return cls( return cls(
dc_id=dc_id, dc_id=dc_id,
addr=addr, addr=addr,
_connector=connector,
_reconnection_policy=reconnection_policy,
_logger=base_logger.getChild("mtsender"), _logger=base_logger.getChild("mtsender"),
_reader=reader, _reader=reader,
_writer=writer, _writer=writer,
@ -206,6 +214,7 @@ class Sender:
_requests=[], _requests=[],
_next_ping=asyncio.get_running_loop().time() + PING_DELAY, _next_ping=asyncio.get_running_loop().time() + PING_DELAY,
_read_buffer=bytearray(), _read_buffer=bytearray(),
_write_drain_pending=False,
) )
async def disconnect(self) -> None: async def disconnect(self) -> None:
@ -236,15 +245,21 @@ class Sender:
if rx.done(): if rx.done():
return rx.result() return rx.result()
async def step(self) -> None: async def step(self):
try:
await self._step()
except Exception as error:
await self._on_error(error)
async def _step(self) -> None:
if not self._writing: if not self._writing:
self._writing = True self._writing = True
await self._do_write() await self._do_send()
self._writing = False self._writing = False
if not self._reading: if not self._reading:
self._reading = True self._reading = True
await self._do_read() await self._do_recv()
self._reading = False self._reading = False
else: else:
await self._step_done.wait() await self._step_done.wait()
@ -254,7 +269,7 @@ class Sender:
self._updates.clear() self._updates.clear()
return updates return updates
async def _do_read(self) -> None: async def _do_recv(self) -> None:
self._step_done.clear() self._step_done.clear()
timeout = self._next_ping - asyncio.get_running_loop().time() timeout = self._next_ping - asyncio.get_running_loop().time()
@ -266,10 +281,46 @@ class Sender:
else: else:
self._on_net_read(recv_data) self._on_net_read(recv_data)
finally: finally:
self._try_timeout_ping() self._try_ping_timeout()
self._step_done.set() self._step_done.set()
async def _do_write(self) -> None: async def _do_send(self) -> None:
self._try_fill_write()
if self._write_drain_pending:
await self._writer.drain()
self._on_net_write()
async def _try_connect(self):
attempts = 0
ip, port = self.addr.split(":")
while True:
try:
self._reader, self._writer = await self._connector(ip, int(port))
self._logger.info(
f"auto-reconnect success after {attempts} failed attempt(s)"
)
return
except Exception as e:
attempts += 1
self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}")
await asyncio.sleep(1)
delay = self._reconnection_policy.should_retry(attempts)
if delay:
if delay is not True:
await asyncio.sleep(delay)
continue
else:
self._logger.error(
f"auto-reconnect failed {attempts} time(s); giving up"
)
raise
def _try_fill_write(self) -> None:
if not self._requests: if not self._requests:
return return
@ -283,15 +334,14 @@ class Sender:
result = self._mtp.finalize() result = self._mtp.finalize()
if result: if result:
container_msg_id, mtp_buffer = result container_msg_id, mtp_buffer = result
self._transport.pack(mtp_buffer, self._writer.write)
await self._writer.drain()
for request in self._requests: for request in self._requests:
if isinstance(request.state, Serialized): if isinstance(request.state, Serialized):
request.state = Sent(request.state.msg_id, container_msg_id) request.state.container_msg_id = container_msg_id
def _try_timeout_ping(self) -> None: self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True
def _try_ping_timeout(self) -> None:
current_time = asyncio.get_running_loop().time() current_time = asyncio.get_running_loop().time()
if current_time >= self._next_ping: if current_time >= self._next_ping:
@ -321,6 +371,45 @@ class Sender:
del self._read_buffer[:n] del self._read_buffer[:n]
self._process_mtp_buffer() self._process_mtp_buffer()
def _on_net_write(self) -> None:
for req in self._requests:
if isinstance(req.state, Serialized):
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
async def _on_error(self, error: Exception) -> None:
self._logger.info(f"handling error: {error}")
self._transport.reset()
self._mtp.reset()
self._logger.info(
"resetting sender state from read_buffer {}, mtp_buffer {}".format(
len(self._read_buffer),
len(self._mtp_buffer),
)
)
self._read_buffer.clear()
self._mtp_buffer.clear()
if isinstance(error, struct.error) and self._reconnection_policy.should_retry(
0
):
self._logger.info(f"read error occurred: {error}")
await self._try_connect()
for req in self._requests:
req.state = NotSerialized()
self._updates.append(UpdatesTooLong())
return
self._logger.warning(
f"marking all {len(self._requests)} request(s) as failed: {error}"
)
for req in self._requests:
req.result.set_exception(error)
raise error
def _process_mtp_buffer(self) -> None: def _process_mtp_buffer(self) -> None:
results = self._mtp.deserialize(self._mtp_buffer) results = self._mtp.deserialize(self._mtp_buffer)
@ -331,8 +420,14 @@ class Sender:
self._process_result(result) self._process_result(result)
elif isinstance(result, RpcError): elif isinstance(result, RpcError):
self._process_error(result) self._process_error(result)
else: elif isinstance(result, BadMessageError):
self._process_bad_message(result) self._process_bad_message(result)
elif isinstance(result, DeserializationFailure):
self._process_deserialize_error(result)
else:
raise RuntimeError(
f"unexpected result type {type(result).__name__}: {result}"
)
def _process_update(self, update: bytes | bytearray | memoryview) -> None: def _process_update(self, update: bytes | bytearray | memoryview) -> None:
try: try:
@ -421,6 +516,17 @@ class Sender:
result._caused_by = struct.unpack_from("<I", req.body)[0] result._caused_by = struct.unpack_from("<I", req.body)[0]
req.result.set_exception(result) req.result.set_exception(result)
def _process_deserialize_error(self, failure: DeserializationFailure):
req = self._pop_request(failure.msg_id)
if req:
self._logger.debug(f"got deserialization failure {failure.error}")
req.result.set_exception(failure.error)
else:
self._logger.info(
f"got deserialization failure {failure.error} but no such request is saved"
)
def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]: def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]:
for i, req in enumerate(self._requests): for i, req in enumerate(self._requests):
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
@ -459,6 +565,7 @@ async def connect(
auth_key: Optional[bytes], auth_key: Optional[bytes],
base_logger: logging.Logger, base_logger: logging.Logger,
connector: Connector, connector: Connector,
reconnection_policy: ReconnectionPolicy,
) -> Sender: ) -> Sender:
if auth_key is None: if auth_key is None:
sender = await Sender.connect( sender = await Sender.connect(
@ -467,6 +574,7 @@ async def connect(
dc_id, dc_id,
addr, addr,
connector=connector, connector=connector,
reconnection_policy=reconnection_policy,
base_logger=base_logger, base_logger=base_logger,
) )
return await generate_auth_key(sender) return await generate_auth_key(sender)
@ -477,6 +585,7 @@ async def connect(
dc_id, dc_id,
addr, addr,
connector=connector, connector=connector,
reconnection_policy=reconnection_policy,
base_logger=base_logger, base_logger=base_logger,
) )