Port mtproto from grammers

This commit is contained in:
Lonami Exo
2023-07-09 21:16:55 +02:00
parent 9636ef35c1
commit 269ee4f05f
35 changed files with 1747 additions and 57 deletions

View File

@@ -29,35 +29,41 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
class Reader:
__slots__ = ("_buffer", "_pos", "_view")
__slots__ = ("_view", "_pos", "_len")
def __init__(self, buffer: bytes) -> None:
self._buffer = buffer
self._view = (
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
)
self._pos = 0
self._view = memoryview(self._buffer)
self._len = len(self._view)
def read_remaining(self) -> bytes:
return self.read(self._len - self._pos)
def read(self, n: int) -> bytes:
self._pos += n
return self._view[self._pos - n : n]
assert self._pos <= self._len
return self._view[self._pos - n : self._pos]
def read_fmt(self, fmt: str, size: int) -> tuple[Any, ...]:
assert struct.calcsize(fmt) == size
self._pos += size
assert self._pos <= self._len
return struct.unpack(fmt, self._view[self._pos - size : self._pos])
def read_bytes(self) -> bytes:
if self._buffer[self._pos] == 254:
if self._view[self._pos] == 254:
self._pos += 4
(length,) = struct.unpack(
"<i", self._buffer[self._pos - 3 : self._pos] + b"\0"
)
length = struct.unpack("<i", self._view[self._pos - 4 : self._pos])[0] >> 8
padding = length % 4
else:
length = self._buffer[self._pos]
length = self._view[self._pos]
padding = (length + 1) % 4
self._pos += 1
self._pos += length
assert self._pos <= self._len
data = self._view[self._pos - length : self._pos]
if padding > 0:
self._pos += 4 - padding
@@ -72,6 +78,7 @@ class Reader:
# Unfortunately `typing.cast` would add a tiny amount of runtime overhead
# which cannot be removed with optimization enabled.
self._pos += 4
assert self._pos <= self._len
cid = struct.unpack("<I", self._view[self._pos - 4 : self._pos])[0]
ty = self._get_ty(cid)
if ty is None:

View File

@@ -1,16 +1,13 @@
import struct
class Request:
__slots__ = "_body"
def __init__(self, body: bytes):
self._body = body
class Request(bytes):
__slots__ = ()
@property
def constructor_id(self) -> int:
try:
cid = struct.unpack("<i", self._body[:4])[0]
cid = struct.unpack("<i", self[:4])[0]
assert isinstance(cid, int)
return cid
except struct.error:

View File

@@ -35,7 +35,12 @@ class Serializable(abc.ABC):
return bytes(buffer)
def __repr__(self) -> str:
attrs = ", ".join(repr(getattr(self, attr)) for attr in self.__slots__)
fields = ((attr, getattr(self, attr)) for attr in self.__slots__)
fields = (
(name, bytes(field) if isinstance(field, memoryview) else field)
for name, field in fields
)
attrs = ", ".join(f"{name}={field!r}" for name, field in fields)
return f"{self.__class__.__name__}({attrs})"
def __eq__(self, other: object) -> bool: