From 0e5ea59ecf2f050c67b91cb4d8b5e4d46c75fa81 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Fri, 18 Oct 2024 19:07:45 +0200 Subject: [PATCH] Use asyncio.BufferedProtocol in sender --- .../src/telethon/_impl/mtsender/__init__.py | 2 - .../src/telethon/_impl/mtsender/protocol.py | 59 +++++++++++++ client/src/telethon/_impl/mtsender/sender.py | 87 +++++++------------ 3 files changed, 88 insertions(+), 60 deletions(-) create mode 100644 client/src/telethon/_impl/mtsender/protocol.py diff --git a/client/src/telethon/_impl/mtsender/__init__.py b/client/src/telethon/_impl/mtsender/__init__.py index bc9d723f..4ba3c76d 100644 --- a/client/src/telethon/_impl/mtsender/__init__.py +++ b/client/src/telethon/_impl/mtsender/__init__.py @@ -1,5 +1,4 @@ from .sender import ( - MAXIMUM_DATA, NO_PING_DISCONNECT, PING_DELAY, AsyncReader, @@ -10,7 +9,6 @@ from .sender import ( ) __all__ = [ - "MAXIMUM_DATA", "NO_PING_DISCONNECT", "PING_DELAY", "AsyncReader", diff --git a/client/src/telethon/_impl/mtsender/protocol.py b/client/src/telethon/_impl/mtsender/protocol.py new file mode 100644 index 00000000..7d803b9e --- /dev/null +++ b/client/src/telethon/_impl/mtsender/protocol.py @@ -0,0 +1,59 @@ +import asyncio +from ..mtproto import ( + MissingBytesError, + Transport, +) + +MAXIMUM_DATA = (1024 * 1024) + (8 * 1024) + + +class BufferedTransportProtocol(asyncio.BufferedProtocol): + __slots__ = ( + "_transport", + "_buffer", + "_buffer_head", + "_packets", + "_output", + "_closed", + ) + + def __init__(self, transport: Transport): + self._transport = transport + self._buffer = bytearray(MAXIMUM_DATA) + self._buffer_head = 0 + self._packets: asyncio.Queue[bytes] = asyncio.Queue() + self._output = bytearray() + self._closed = asyncio.Event() + + # Method overrides + + def get_buffer(self, sizehint): + return self._buffer + + def buffer_updated(self, nbytes): + self._buffer_head += nbytes + while self._buffer_head: + self._output.clear() + try: + n = self._transport.unpack( + memoryview(self._buffer)[: self._buffer_head], self._output + ) + except MissingBytesError as e: + print(e) + return + else: + del self._buffer[:n] + self._buffer += bytes(n) + self._buffer_head -= n + self._packets.put_nowait(bytes(self._output)) + + def connection_lost(self, exc): + self._closed.set() + + # Custom methods + + def wait_closed(self): + return self._closed.wait() + + def wait_packet(self): + return self._packets.get() diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 35dabfd1..9513adb4 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -10,11 +10,11 @@ from typing import Generic, Optional, Protocol, Type, TypeVar from typing_extensions import Self +from .protocol import BufferedTransportProtocol from ..crypto import AuthKey from ..mtproto import ( BadMessageError, Encrypted, - MissingBytesError, MsgId, Mtp, Plain, @@ -31,7 +31,6 @@ from ..tl.mtproto.functions import ping_delay_disconnect from ..tl.types import UpdateDeleteMessages, UpdateShort from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages -MAXIMUM_DATA = (1024 * 1024) + (8 * 1024) PING_DELAY = 60 @@ -164,16 +163,13 @@ class Sender: addr: str lock: Lock _logger: logging.Logger - _reader: AsyncReader - _writer: AsyncWriter + _connection: asyncio.Transport _transport: Transport + _protocol: BufferedTransportProtocol _mtp: Mtp - _mtp_buffer: bytearray _requests: list[Request[object]] _request_event: Event _next_ping: float - _read_buffer: bytearray - _write_drain_pending: bool _step_counter: int @classmethod @@ -188,29 +184,29 @@ class Sender: base_logger: logging.Logger, ) -> Self: ip, port = addr.split(":") - reader, writer = await connector(ip, int(port)) + # TODO BRING BACK SUPPORT FOR connector + connection, protocol = await asyncio.get_running_loop().create_connection( + lambda: BufferedTransportProtocol(transport), ip, int(port) + ) return cls( dc_id=dc_id, addr=addr, lock=Lock(), _logger=base_logger.getChild("mtsender"), - _reader=reader, - _writer=writer, + _connection=connection, _transport=transport, + _protocol=protocol, _mtp=mtp, - _mtp_buffer=bytearray(), _requests=[], _request_event=Event(), _next_ping=asyncio.get_running_loop().time() + PING_DELAY, - _read_buffer=bytearray(), - _write_drain_pending=False, _step_counter=0, ) async def disconnect(self) -> None: - self._writer.close() - await self._writer.wait_closed() + self._connection.close() + await self._protocol.wait_closed() def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]: rx = self._enqueue_body(bytes(request)) @@ -251,14 +247,20 @@ class Sender: async def _step(self) -> list[Updates]: self._try_fill_write() + self._connection.resume_reading() recv_req = asyncio.create_task(self._request_event.wait()) - recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA)) - send_data = asyncio.create_task(self._do_send()) + recv_data = asyncio.create_task(self._protocol.wait_packet()) + conn_lost = asyncio.create_task(self._protocol.wait_closed()) done, pending = await asyncio.wait( - (recv_req, recv_data, send_data), + ( + recv_req, + recv_data, + conn_lost, + ), timeout=self._next_ping - asyncio.get_running_loop().time(), return_when=FIRST_COMPLETED, ) + self._connection.pause_reading() if pending: for task in pending: @@ -270,24 +272,13 @@ class Sender: self._request_event.clear() if recv_data in done: result = self._on_net_read(recv_data.result()) - if send_data in done: - self._on_net_write() + if conn_lost in done: + raise ConnectionResetError if not done: self._on_ping_timeout() return result - async def _do_send(self) -> None: - if self._write_drain_pending: - await self._writer.drain() - self._write_drain_pending = False - else: - # Never return - await asyncio.get_running_loop().create_future() - def _try_fill_write(self) -> None: - if self._write_drain_pending: - return - for request in self._requests: if isinstance(request.state, NotSerialized): if (msg_id := self._mtp.push(request.body)) is not None: @@ -298,37 +289,17 @@ class Sender: result = self._mtp.finalize() if result: container_msg_id, mtp_buffer = result + + self._transport.pack(mtp_buffer, self._connection.write) for request in self._requests: if isinstance(request.state, Serialized): - request.state.container_msg_id = container_msg_id - - self._transport.pack(mtp_buffer, self._writer.write) - self._write_drain_pending = True - - def _on_net_read(self, read_buffer: bytes) -> list[Updates]: - if not read_buffer: - raise ConnectionResetError("read 0 bytes") - - self._read_buffer += read_buffer + request.state = Sent(request.state.msg_id, container_msg_id) + def _on_net_read(self, mtp_buffer: bytes) -> list[Updates]: updates: list[Updates] = [] - while self._read_buffer: - self._mtp_buffer.clear() - try: - n = self._transport.unpack(self._read_buffer, self._mtp_buffer) - except MissingBytesError: - break - else: - del self._read_buffer[:n] - self._process_mtp_buffer(updates) - + self._process_mtp_buffer(mtp_buffer, updates) return updates - def _on_net_write(self) -> None: - for req in self._requests: - if isinstance(req.state, Serialized): - req.state = Sent(req.state.msg_id, req.state.container_msg_id) - def _on_ping_timeout(self) -> None: ping_id = generate_random_id() self._enqueue_body( @@ -340,8 +311,8 @@ class Sender: ) self._next_ping = asyncio.get_running_loop().time() + PING_DELAY - def _process_mtp_buffer(self, updates: list[Updates]) -> None: - results = self._mtp.deserialize(self._mtp_buffer) + def _process_mtp_buffer(self, mtp_buffer: bytes, updates: list[Updates]) -> None: + results = self._mtp.deserialize(mtp_buffer) for result in results: if isinstance(result, Update):