diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 10ec0382..fee27620 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -163,6 +163,7 @@ class Request(Generic[Return]): class Sender: dc_id: int addr: str + _connector: Connector _logger: logging.Logger _reader: AsyncReader _writer: AsyncWriter @@ -176,6 +177,7 @@ class Sender: _requests: list[Request[object]] _next_ping: float _read_buffer: bytearray + _write_drain_pending: bool @classmethod async def connect( @@ -194,6 +196,7 @@ class Sender: return cls( dc_id=dc_id, addr=addr, + _connector=connector, _logger=base_logger.getChild("mtsender"), _reader=reader, _writer=writer, @@ -207,6 +210,7 @@ class Sender: _requests=[], _next_ping=asyncio.get_running_loop().time() + PING_DELAY, _read_buffer=bytearray(), + _write_drain_pending=False, ) async def disconnect(self) -> None: @@ -237,28 +241,31 @@ class Sender: if rx.done(): return rx.result() - async def step(self) -> None: + async def step(self): try: - if not self._writing: - self._writing = True - await self._do_write() - self._writing = False + await self._step() + except Exception as error: + self._on_error(error) - if not self._reading: - self._reading = True - await self._do_read() - self._reading = False - else: - await self._step_done.wait() - except Exception as e: - self._on_error(e) + async def _step(self) -> None: + if not self._writing: + self._writing = True + await self._do_send() + self._writing = False + + if not self._reading: + self._reading = True + await self._do_recv() + self._reading = False + else: + await self._step_done.wait() def pop_updates(self) -> list[Updates]: updates = self._updates[:] self._updates.clear() return updates - async def _do_read(self) -> None: + async def _do_recv(self) -> None: self._step_done.clear() timeout = self._next_ping - asyncio.get_running_loop().time() @@ -270,10 +277,31 @@ class Sender: else: self._on_net_read(recv_data) finally: - self._try_timeout_ping() + self._try_ping_timeout() self._step_done.set() - async def _do_write(self) -> None: + async def _do_send(self) -> None: + self._try_fill_write() + + if self._write_drain_pending: + await self._writer.drain() + self._on_net_write() + + async def try_connect(self): + # attempts = 0 + + ip, port = self.addr.split(":") + + while True: + try: + self._reader, self._writer = await self._connector(ip, int(port)) + break + except Exception as e: + logging.exception(e) + # TODO: reconnection_policy + break + + def _try_fill_write(self) -> None: if not self._requests: return @@ -287,15 +315,14 @@ class Sender: result = self._mtp.finalize() if result: container_msg_id, mtp_buffer = result - - self._transport.pack(mtp_buffer, self._writer.write) - await self._writer.drain() - for request in self._requests: if isinstance(request.state, Serialized): - request.state = Sent(request.state.msg_id, container_msg_id) + request.state.container_msg_id = container_msg_id - def _try_timeout_ping(self) -> None: + self._transport.pack(mtp_buffer, self._writer.write) + self._write_drain_pending = True + + def _try_ping_timeout(self) -> None: current_time = asyncio.get_running_loop().time() if current_time >= self._next_ping: @@ -325,6 +352,11 @@ class Sender: del self._read_buffer[:n] self._process_mtp_buffer() + def _on_net_write(self) -> None: + for req in self._requests: + if isinstance(req.state, Serialized): + req.state = Sent(req.state.msg_id, req.state.container_msg_id) + def _on_error(self, error: Exception): logging.info(f"Handling error: {error}") self._transport.reset() @@ -364,7 +396,9 @@ class Sender: elif isinstance(result, DeserializationFailure): self._process_deserialize_error(result) else: - raise RuntimeError(f"Unexpected result: {result}") + raise RuntimeError( + f"Unexpected result type {type(result).__name__!r}: {result}" + ) def _process_update(self, update: bytes | bytearray | memoryview) -> None: try: @@ -477,7 +511,6 @@ class Sender: def _drain_requests(self, msg_id: MsgId) -> Iterator[Request[object]]: for i in reversed(range(len(self._requests))): req = self._requests[i] - if isinstance(req.state, Serialized) and ( req.state.msg_id == msg_id or req.state.container_msg_id == msg_id ):