diff --git a/DEVELOPING.md b/DEVELOPING.md index 342e776d..b248bea0 100644 --- a/DEVELOPING.md +++ b/DEVELOPING.md @@ -1,5 +1,13 @@ +Code generation: + ```sh pip install -e generator/ -python -m telethon_generator.codegen api.tl telethon/src/_impl/tl -python -m telethon_generator.codegen mtproto.tl telethon/src/_impl/tl/mtproto +python -m telethon_generator.codegen api.tl client/src/telethon/_impl/tl +python -m telethon_generator.codegen mtproto.tl client/src/telethon/_impl/tl/mtproto +``` + +Formatting, type-checking and testing: + +``` +./check.sh ``` diff --git a/client/src/telethon/_impl/tl/core/reader.py b/client/src/telethon/_impl/tl/core/reader.py index cd608a36..8d5f5338 100644 --- a/client/src/telethon/_impl/tl/core/reader.py +++ b/client/src/telethon/_impl/tl/core/reader.py @@ -1,5 +1,5 @@ import struct -from typing import TYPE_CHECKING, Any, Type, TypeVar +from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar if TYPE_CHECKING: from .serializable import Serializable @@ -8,6 +8,26 @@ if TYPE_CHECKING: T = TypeVar("T", bound="Serializable") +def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]: + # Lazy import because generate code depends on the Reader. + # After the first call, the class method is replaced with direct access. + if Reader._get_ty is _bootstrap_get_ty: + from ..layer import TYPE_MAPPING as API_TYPES + from ..mtproto.layer import TYPE_MAPPING as MTPROTO_TYPES + + if API_TYPES.keys() & MTPROTO_TYPES.keys(): + raise RuntimeError( + "generated api and mtproto schemas cannot have colliding constructor identifiers" + ) + ALL_TYPES = API_TYPES | MTPROTO_TYPES + + # Signatures don't fully match, but this is a private method + # and all previous uses are compatible with `dict.get`. + Reader._get_ty = ALL_TYPES.get # type: ignore [assignment] + + return Reader._get_ty(constructor_id) + + class Reader: __slots__ = ("_buffer", "_pos", "_view") @@ -44,11 +64,7 @@ class Reader: return data - @staticmethod - def _get_ty(_: int) -> Type["Serializable"]: - # Implementation replaced during import to prevent cycles, - # without the performance hit of having the import inside. - raise NotImplementedError + _get_ty = staticmethod(_bootstrap_get_ty) def read_serializable(self, cls: Type[T]) -> T: # Calls to this method likely need to ignore "type-abstract". diff --git a/client/src/telethon/_impl/tl/core/serializable.py b/client/src/telethon/_impl/tl/core/serializable.py index e5072559..622c0a32 100644 --- a/client/src/telethon/_impl/tl/core/serializable.py +++ b/client/src/telethon/_impl/tl/core/serializable.py @@ -38,6 +38,13 @@ class Serializable(abc.ABC): attrs = ", ".join(repr(getattr(self, attr)) for attr in self.__slots__) return f"{self.__class__.__name__}({attrs})" + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + return all( + getattr(self, attr) == getattr(other, attr) for attr in self.__slots__ + ) + def serialize_bytes_to(buffer: bytearray, data: bytes) -> None: length = len(data) diff --git a/client/src/telethon/_impl/tl/mtproto/__init__.py b/client/src/telethon/_impl/tl/mtproto/__init__.py new file mode 100644 index 00000000..1d0af0b3 --- /dev/null +++ b/client/src/telethon/_impl/tl/mtproto/__init__.py @@ -0,0 +1,4 @@ +from . import abcs, functions, types +from .layer import TYPE_MAPPING + +__all__ = ["abcs", "functions", "types", "TYPE_MAPPING"] diff --git a/client/tests/reader_test.py b/client/tests/reader_test.py index 52636f81..04946f45 100644 --- a/client/tests/reader_test.py +++ b/client/tests/reader_test.py @@ -1,5 +1,10 @@ +import struct + from pytest import mark from telethon._impl.tl.core import Reader +from telethon._impl.tl.core.serializable import Serializable +from telethon._impl.tl.mtproto.types import BadServerSalt +from telethon._impl.tl.types import GeoPoint @mark.parametrize( @@ -24,3 +29,21 @@ sentence made it past!", def test_string(string: str, prefix: bytes, suffix: bytes) -> None: data = prefix + string.encode("ascii") + suffix assert str(Reader(data).read_bytes(), "ascii") == string + + +@mark.parametrize( + "obj", + [ + GeoPoint(long=12.34, lat=56.78, access_hash=123123, accuracy_radius=100), + BadServerSalt( + bad_msg_id=1234, + bad_msg_seqno=5678, + error_code=9876, + new_server_salt=5432, + ), + ], +) +def test_generated_object(obj: Serializable) -> None: + assert bytes(obj)[:4] == struct.pack(" None: params = "".join( f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params ) - writer.write(f" def __init__(_s{params}) -> None:") + writer.write(f" def __init__(_s, *{params}) -> None:") for p in property_params: writer.write(f" _s.{p.name} = {p.name}") @@ -183,7 +183,4 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: for name in sorted(generated_type_names): writer.write(f" types.{name},") writer.write("))}") - writer.write( - "Reader._get_ty = TYPE_MAPPING.get # type: ignore [method-assign, assignment]" - ) writer.write(f"__all__ = ['LAYER', 'TYPE_MAPPING']")