From 9ba6e2ded67b48816951381b2ba489c5f8f6a46d Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Fri, 1 Sep 2023 13:25:17 +0200 Subject: [PATCH] Generate deserializers for requests --- client/src/telethon/_impl/tl/core/__init__.py | 23 ++++++++- client/src/telethon/_impl/tl/core/reader.py | 47 ++++++++++++++++++- client/src/telethon/_impl/tl/core/request.py | 28 ++++++++++- .../_impl/codegen/generator.py | 20 ++++++-- .../_impl/codegen/serde/deserialization.py | 40 ++++++++++++++++ 5 files changed, 150 insertions(+), 8 deletions(-) diff --git a/client/src/telethon/_impl/tl/core/__init__.py b/client/src/telethon/_impl/tl/core/__init__.py index eb86687f..0d1b1b23 100644 --- a/client/src/telethon/_impl/tl/core/__init__.py +++ b/client/src/telethon/_impl/tl/core/__init__.py @@ -1,5 +1,24 @@ -from .reader import Reader +from .reader import ( + Reader, + deserialize_bool, + deserialize_i32_list, + deserialize_i64_list, + deserialize_identity, + list_deserializer, + single_deserializer, +) from .request import Request from .serializable import Serializable, serialize_bytes_to -__all__ = ["Reader", "Request", "Serializable", "serialize_bytes_to"] +__all__ = [ + "Reader", + "deserialize_bool", + "deserialize_i32_list", + "deserialize_i64_list", + "deserialize_identity", + "list_deserializer", + "single_deserializer", + "Request", + "Serializable", + "serialize_bytes_to", +] diff --git a/client/src/telethon/_impl/tl/core/reader.py b/client/src/telethon/_impl/tl/core/reader.py index 703b3d1b..7dd0ca03 100644 --- a/client/src/telethon/_impl/tl/core/reader.py +++ b/client/src/telethon/_impl/tl/core/reader.py @@ -1,5 +1,6 @@ +import functools import struct -from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Type, TypeVar if TYPE_CHECKING: from .serializable import Serializable @@ -85,3 +86,47 @@ class Reader: raise ValueError(f"No type found for constructor ID: {cid:x}") assert issubclass(ty, cls) return ty._read_from(self) + + +@functools.cache +def single_deserializer(cls: Type[T]) -> Callable[[bytes], T]: + def deserializer(body: bytes) -> T: + return Reader(body).read_serializable(cls) + + return deserializer + + +@functools.cache +def list_deserializer(cls: Type[T]) -> Callable[[bytes], List[T]]: + def deserializer(body: bytes) -> List[T]: + reader = Reader(body) + vec_id, length = reader.read_fmt("= 0 + return [reader.read_serializable(cls) for _ in range(length)] + + return deserializer + + +def deserialize_i64_list(body: bytes) -> List[int]: + reader = Reader(body) + vec_id, length = reader.read_fmt("= 0 + return [*reader.read_fmt(f"<{length}q", length * 8)] + + +def deserialize_i32_list(body: bytes) -> List[int]: + reader = Reader(body) + vec_id, length = reader.read_fmt("= 0 + return [*reader.read_fmt(f"<{length}i", length * 4)] + + +def deserialize_identity(body: bytes) -> bytes: + return body + + +def deserialize_bool(body: bytes) -> bool: + reader = Reader(body) + bool_id = reader.read_fmt(" Optional[Callable[[bytes], Any]]: + # Similar to Reader's bootstrapping. + if Request._get_deserializer is _bootstrap_get_deserializer: + from ..layer import RESPONSE_MAPPING as API_DESER + from ..mtproto.layer import RESPONSE_MAPPING as MTPROTO_DESER + + if API_DESER.keys() & MTPROTO_DESER.keys(): + raise RuntimeError( + "generated api and mtproto schemas cannot have colliding constructor identifiers" + ) + ALL_DESER = API_DESER | MTPROTO_DESER + + Request._get_deserializer = ALL_DESER.get # type: ignore [assignment] + + return Request._get_deserializer(constructor_id) + + class Request(bytes, Generic[Return]): __slots__ = () @@ -16,5 +35,12 @@ class Request(bytes, Generic[Return]): except struct.error: return 0 + _get_deserializer = staticmethod(_bootstrap_get_deserializer) + + def deserialize_response(self, response: bytes) -> Return: + deserializer = self._get_deserializer(self.constructor_id) + assert deserializer is not None + return deserializer(response) # type: ignore [no-any-return] + def debug_name(self) -> str: return f"request#{self.constructor_id:x}" diff --git a/generator/src/telethon_generator/_impl/codegen/generator.py b/generator/src/telethon_generator/_impl/codegen/generator.py index 041ef55a..4a2fc9ae 100644 --- a/generator/src/telethon_generator/_impl/codegen/generator.py +++ b/generator/src/telethon_generator/_impl/codegen/generator.py @@ -11,7 +11,11 @@ from .serde.common import ( to_class_name, to_method_name, ) -from .serde.deserialization import generate_read, param_value_fmt +from .serde.deserialization import ( + function_deserializer_fmt, + generate_read, + param_value_fmt, +) from .serde.serialization import generate_function, generate_write @@ -178,8 +182,10 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: ) writer = fs.open(Path("layer.py")) - writer.write(f"from . import types") - writer.write(f"from .core import Serializable, Reader") + writer.write(f"from . import abcs, types") + writer.write( + f"from .core import Serializable, Reader, deserialize_bool, deserialize_i32_list, deserialize_i64_list, deserialize_identity, single_deserializer, list_deserializer" + ) writer.write(f"from typing import cast, Tuple, Type") writer.write(f"LAYER = {tl.layer!r}") writer.write( @@ -188,4 +194,10 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: for name in sorted(generated_type_names): writer.write(f" types.{name},") writer.write("))}") - writer.write(f"__all__ = ['LAYER', 'TYPE_MAPPING']") + writer.write("RESPONSE_MAPPING = {") + for functiondef in tl.functiondefs: + writer.write( + f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)}," + ) + writer.write("}") + writer.write(f"__all__ = ['LAYER', 'TYPE_MAPPING', 'RESPONSE_MAPPING']") diff --git a/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py b/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py index 1faf69d8..10c73a6e 100644 --- a/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py +++ b/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py @@ -122,3 +122,43 @@ def param_value_fmt(param: Parameter) -> str: return f"_{param.name} == 0x997275b5" else: return f"_{param.name}" + + +def function_deserializer_fmt(defn: Definition) -> str: + if defn.ty.generic_arg: + if defn.ty.name != ("Vector"): + raise NotImplementedError( + "generic_arg return for non-boxed-vectors not implemented" + ) + elif defn.ty.generic_ref: + raise NotImplementedError( + "return for generic refs inside vector not implemented" + ) + elif is_trivial(NormalParameter(ty=defn.ty.generic_arg, flag=None)): + if defn.ty.generic_arg.name == "int": + return "deserialize_i32_list" + elif defn.ty.generic_arg.name == "long": + return "deserialize_i64_list" + else: + raise NotImplementedError( + f"return for trivial arg {defn.ty.generic_arg} not implemented" + ) + elif defn.ty.generic_arg.bare: + raise NotImplementedError( + "return for non-boxed serializables inside a vector not implemented" + ) + else: + return f"list_deserializer({inner_type_fmt(defn.ty.generic_arg)})" + elif defn.ty.generic_ref: + return "deserialize_identity" + elif is_trivial(NormalParameter(ty=defn.ty, flag=None)): + if defn.ty.name == "Bool": + return "deserialize_bool" + else: + raise NotImplementedError( + f"return for trivial arg {defn.ty} not implemented" + ) + elif defn.ty.bare: + raise NotImplementedError("return for non-boxed serializables not implemented") + else: + return f"single_deserializer({inner_type_fmt(defn.ty)})"