Remove enqueuer abstraction from sender

Unnecessary complexity since Python lacks exclusive ownership.
This commit is contained in:
Lonami Exo 2023-09-01 11:57:41 +02:00
parent 77b49a1c88
commit 2e1321b6c9
2 changed files with 41 additions and 64 deletions

View File

@ -2,9 +2,9 @@ import asyncio
import struct import struct
import time import time
from abc import ABC from abc import ABC
from asyncio import FIRST_COMPLETED, Future, Queue, StreamReader, StreamWriter from asyncio import FIRST_COMPLETED, Event, Future, StreamReader, StreamWriter
from dataclasses import dataclass from dataclasses import dataclass
from typing import BinaryIO, Generic, List, Optional, Self, Tuple, TypeVar from typing import Generic, List, Optional, Self, TypeVar
from ..crypto.auth_key import AuthKey from ..crypto.auth_key import AuthKey
from ..mtproto import authentication from ..mtproto import authentication
@ -68,22 +68,6 @@ class Request(Generic[Return]):
result: Future[Return] result: Future[Return]
class Enqueuer:
__slots__ = ("_queue",)
def __init__(self, queue: Queue[Request[object]]) -> None:
self._queue = queue
def enqueue(self, request: RemoteCall[Return]) -> Future[Return]:
body = bytes(request)
assert len(body) >= 4
oneshot = asyncio.get_running_loop().create_future()
self._queue.put_nowait(
Request(body=body, state=NotSerialized(), result=oneshot)
)
return oneshot
@dataclass @dataclass
class Sender: class Sender:
_reader: StreamReader _reader: StreamReader
@ -92,38 +76,37 @@ class Sender:
_mtp: Mtp _mtp: Mtp
_mtp_buffer: bytearray _mtp_buffer: bytearray
_requests: List[Request[object]] _requests: List[Request[object]]
_request_rx: Queue[Request[object]] _request_event: Event
_next_ping: float _next_ping: float
_read_buffer: bytearray _read_buffer: bytearray
_write_drain_pending: bool _write_drain_pending: bool
@classmethod @classmethod
async def connect( async def connect(cls, transport: Transport, mtp: Mtp, addr: str) -> Self:
cls, transport: Transport, mtp: Mtp, addr: str
) -> Tuple[Self, Enqueuer]:
reader, writer = await asyncio.open_connection(*addr.split(":")) reader, writer = await asyncio.open_connection(*addr.split(":"))
request_queue: Queue[Request[object]] = Queue()
return ( return cls(
cls(
_reader=reader, _reader=reader,
_writer=writer, _writer=writer,
_transport=transport, _transport=transport,
_mtp=mtp, _mtp=mtp,
_mtp_buffer=bytearray(), _mtp_buffer=bytearray(),
_requests=[], _requests=[],
_request_rx=request_queue, _request_event=Event(),
_next_ping=asyncio.get_running_loop().time() + PING_DELAY, _next_ping=asyncio.get_running_loop().time() + PING_DELAY,
_read_buffer=bytearray(), _read_buffer=bytearray(),
_write_drain_pending=False, _write_drain_pending=False,
),
Enqueuer(request_queue),
) )
async def disconnect(self): async def disconnect(self) -> None:
self._writer.close() self._writer.close()
await self._writer.wait_closed() await self._writer.wait_closed()
def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]:
rx = self._enqueue_body(bytes(request))
self._request_event.set()
return rx
async def invoke(self, request: RemoteCall[Return]) -> bytes: async def invoke(self, request: RemoteCall[Return]) -> bytes:
rx = self._enqueue_body(bytes(request)) rx = self._enqueue_body(bytes(request))
return await self._step_until_receive(rx) return await self._step_until_receive(rx)
@ -146,7 +129,7 @@ class Sender:
async def step(self) -> List[Updates]: async def step(self) -> List[Updates]:
self._try_fill_write() self._try_fill_write()
recv_req = asyncio.create_task(self._request_rx.get()) recv_req = asyncio.create_task(self._request_event.wait())
recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA)) recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA))
send_data = asyncio.create_task(self._do_send()) send_data = asyncio.create_task(self._do_send())
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
@ -161,7 +144,7 @@ class Sender:
result = [] result = []
if recv_req in done: if recv_req in done:
self._requests.append(recv_req.result()) self._request_event.clear()
if recv_data in done: if recv_data in done:
result = self._on_net_read(recv_data.result()) result = self._on_net_read(recv_data.result())
if send_data in done: if send_data in done:
@ -281,15 +264,12 @@ class Sender:
return None return None
async def connect(transport: Transport, addr: str) -> Tuple[Sender, Enqueuer]: async def connect(transport: Transport, addr: str) -> Sender:
sender, enqueuer = await Sender.connect(transport, Plain(), addr) sender = await Sender.connect(transport, Plain(), addr)
return await generate_auth_key(sender, enqueuer) return await generate_auth_key(sender)
async def generate_auth_key( async def generate_auth_key(sender: Sender) -> Sender:
sender: Sender,
enqueuer: Enqueuer,
) -> Tuple[Sender, Enqueuer]:
request, data1 = authentication.step1() request, data1 = authentication.step1()
response = await sender.send(request) response = await sender.send(request)
request, data2 = authentication.step2(data1, response) request, data2 = authentication.step2(data1, response)
@ -301,20 +281,17 @@ async def generate_auth_key(
time_offset = finished.time_offset time_offset = finished.time_offset
first_salt = finished.first_salt first_salt = finished.first_salt
return ( return Sender(
Sender(
_reader=sender._reader, _reader=sender._reader,
_writer=sender._writer, _writer=sender._writer,
_transport=sender._transport, _transport=sender._transport,
_mtp=Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt), _mtp=Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt),
_mtp_buffer=sender._mtp_buffer, _mtp_buffer=sender._mtp_buffer,
_requests=sender._requests, _requests=sender._requests,
_request_rx=sender._request_rx, _request_event=sender._request_event,
_next_ping=time.time() + PING_DELAY, _next_ping=time.time() + PING_DELAY,
_read_buffer=sender._read_buffer, _read_buffer=sender._read_buffer,
_write_drain_pending=sender._write_drain_pending, _write_drain_pending=sender._write_drain_pending,
),
enqueuer,
) )
@ -322,7 +299,7 @@ async def connect_with_auth(
transport: Transport, transport: Transport,
addr: str, addr: str,
auth_key: bytes, auth_key: bytes,
) -> Tuple[Sender, Enqueuer]: ) -> Sender:
return await Sender.connect( return await Sender.connect(
transport, Encrypted(AuthKey.from_bytes(auth_key)), addr transport, Encrypted(AuthKey.from_bytes(auth_key)), addr
) )

View File

@ -22,11 +22,11 @@ def test_invoke_encrypted_method(caplog: LogCaptureFixture) -> None:
def timeout() -> float: def timeout() -> float:
return deadline - asyncio.get_running_loop().time() return deadline - asyncio.get_running_loop().time()
sender, enqueuer = await asyncio.wait_for( sender = await asyncio.wait_for(
connect(Full(), TELEGRAM_DEFAULT_TEST_DC), timeout() connect(Full(), TELEGRAM_DEFAULT_TEST_DC), timeout()
) )
rx = enqueuer.enqueue( rx = sender.enqueue(
functions.invoke_with_layer( functions.invoke_with_layer(
layer=LAYER, layer=LAYER,
query=functions.init_connection( query=functions.init_connection(