Port mtproto from grammers

This commit is contained in:
Lonami Exo
2023-07-09 21:16:55 +02:00
parent 9636ef35c1
commit 269ee4f05f
35 changed files with 1747 additions and 57 deletions

View File

@@ -37,6 +37,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
generated_types = {
"True",
"Bool",
"Object",
} # initial set is considered to be "compiler built-ins"
ignored_types = {"true", "boolTrue", "boolFalse"} # also "compiler built-ins"
@@ -91,7 +92,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
writer.write(f"import struct")
writer.write(f"from typing import List, Optional, Self")
writer.write(f"from .. import abcs")
writer.write(f"from ..core import Reader, serialize_bytes_to")
writer.write(f"from ..core import Reader, Serializable, serialize_bytes_to")
ns = f"{typedef.namespace[0]}." if typedef.namespace else ""
generated_type_names.add(f"{ns}{to_class_name(typedef.name)}")
@@ -160,8 +161,11 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
writer.write(f"from ..core import Request, serialize_bytes_to")
# def name(params, ...)
params = ", ".join(f"{p.name}: {param_type_fmt(p.ty)}" for p in required_params)
writer.write(f"def {to_method_name(functiondef.name)}({params}) -> Request:")
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params)
star = "*" if params else ""
writer.write(
f"def {to_method_name(functiondef.name)}({star}{params}) -> Request:"
)
writer.indent(2)
generate_function(writer, functiondef)
writer.dedent(2)

View File

@@ -1,18 +1,28 @@
import re
from typing import Iterator
from typing import Iterator, List
from ....tl_parser import BaseParameter, FlagsParameter, NormalParameter, Type
def split_words(name: str) -> List[str]:
return re.findall(
r"""
^$
|[a-z\d]+
|[A-Z][A-Z\d]+(?=[A-Z]|_|$)
|[A-Z][a-z\d]+
""",
name,
re.VERBOSE,
)
def to_class_name(name: str) -> str:
return re.sub(r"(?:^|_)([a-z])", lambda m: m[1].upper(), name)
return "".join(word.title() for word in split_words(name))
def to_method_name(name: str) -> str:
snake_case = re.sub(
r"_+[A-Za-z]+|[A-Z]*[a-z]+", lambda m: "_" + m[0].replace("_", "").lower(), name
)
return snake_case.strip("_")
return "_".join(word.lower() for word in split_words(name))
def gen_tmp_names() -> Iterator[str]:
@@ -69,6 +79,8 @@ def inner_type_fmt(ty: Type) -> str:
return to_class_name(ty.name)
elif ty.generic_ref:
return "bytes"
elif ty.name == "Object":
return "Serializable"
else:
ns = (".".join(ty.namespace) + ".") if ty.namespace else ""
return f"abcs.{ns}{to_class_name(ty.name)}"
@@ -91,7 +103,7 @@ def param_type_fmt(ty: BaseParameter) -> str:
else:
inner_ty = ty.ty
res = inner_type_fmt(inner_ty)
res = "bytes" if inner_ty.name == "Object" else inner_type_fmt(inner_ty)
if ty.ty.generic_arg:
res = f"List[{res}]"

View File

@@ -6,8 +6,15 @@ from ....tl_parser import Definition, NormalParameter, Parameter, Type
from ..fakefs import SourceWriter
from .common import inner_type_fmt, is_trivial, to_class_name, trivial_struct_fmt
# Some implementations choose to create these types by hand.
# For consistency, we instead special-case the generator.
SPECIAL_CASED_OBJECT_READS = {
0xF35C6D01: "reader.read_remaining()", # rpc_result
0x5BB8E511: "reader.read(_bytes)", # message
}
def reader_read_fmt(ty: Type) -> Tuple[str, Optional[str]]:
def reader_read_fmt(ty: Type, constructor_id: int) -> Tuple[str, Optional[str]]:
if is_trivial(NormalParameter(ty=ty, flag=None)):
fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None))
size = struct.calcsize(f"<{fmt}")
@@ -17,17 +24,22 @@ def reader_read_fmt(ty: Type) -> Tuple[str, Optional[str]]:
elif ty.name == "bytes":
return f"reader.read_bytes()", None
elif ty.name == "int128":
return f"int.from_bytes(reader.read(16), 'little', signed=True)", None
return f"int.from_bytes(reader.read(16))", None
elif ty.name == "int256":
return f"int.from_bytes(reader.read(32), 'little', signed=True)", None
return f"int.from_bytes(reader.read(32))", None
elif ty.bare:
return f"{to_class_name(ty.name)}._read_from(reader)", None
elif ty.name == "Object":
try:
return SPECIAL_CASED_OBJECT_READS[constructor_id], None
except KeyError:
raise NotImplementedError("missing special case for object read")
else:
return f"reader.read_serializable({inner_type_fmt(ty)})", "type-abstract"
def generate_normal_param_read(
writer: SourceWriter, name: str, param: NormalParameter
writer: SourceWriter, name: str, param: NormalParameter, constructor_id: int
) -> None:
flag_check = f"_{param.flag.name} & {1 << param.flag.index}" if param.flag else None
if param.ty.name == "true":
@@ -59,7 +71,7 @@ def generate_normal_param_read(
fmt = trivial_struct_fmt(generic)
size = struct.calcsize(f"<{fmt}")
writer.write(
f"_{name} = reader.read_fmt(f'<{{__len}}{fmt}', __len * {size})[0]"
f"_{name} = [*reader.read_fmt(f'<{{__len}}{fmt}', __len * {size})]"
)
if param.ty.generic_arg.name == "Bool":
writer.write(
@@ -67,11 +79,13 @@ def generate_normal_param_read(
)
writer.write(f"_{name} = [_{name} == 0x997275b5]")
else:
fmt_read, type_ignore = reader_read_fmt(param.ty.generic_arg)
fmt_read, type_ignore = reader_read_fmt(
param.ty.generic_arg, constructor_id
)
comment = f" # type: ignore [{type_ignore}]" if type_ignore else ""
writer.write(f"_{name} = [{fmt_read} for _ in range(__len)]{comment}")
else:
fmt_read, type_ignore = reader_read_fmt(param.ty)
fmt_read, type_ignore = reader_read_fmt(param.ty, constructor_id)
comment = f" # type: ignore [{type_ignore}]" if type_ignore else ""
writer.write(f"_{name} = {fmt_read}{comment}")
@@ -97,4 +111,4 @@ def generate_read(writer: SourceWriter, defn: Definition) -> None:
for param in iter:
if not isinstance(param.ty, NormalParameter):
raise RuntimeError("FlagsParameter should be considered trivial")
generate_normal_param_read(writer, param.name, param.ty)
generate_normal_param_read(writer, param.name, param.ty, defn.id)

View File

@@ -26,16 +26,16 @@ def generate_buffer_append(
)
else:
writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})")
elif ty.generic_ref:
elif ty.generic_ref or ty.name == "Object":
writer.write(f"{buffer} += {name}") # assume previously-serialized
elif ty.name == "string":
writer.write(f"serialize_bytes_to({buffer}, {name}.encode('utf-8'))")
elif ty.name == "bytes":
writer.write(f"serialize_bytes_to({buffer}, {name})")
elif ty.name == "int128":
writer.write(f"{buffer} += {name}.to_bytes(16, 'little', signed=True)")
writer.write(f"{buffer} += {name}.to_bytes(16)")
elif ty.name == "int256":
writer.write(f"{buffer} += {name}.to_bytes(32, 'little', signed=True)")
writer.write(f"{buffer} += {name}.to_bytes(32)")
elif ty.bare:
writer.write(f"{name}._write_to({buffer})")
else:

View File

@@ -37,7 +37,7 @@ def iterate(contents: str) -> Iterator[TypeDef | FunctionDef | Exception]:
definition = definition[len(FUNCTIONS_SEP) :].strip()
elif definition.startswith(TYPES_SEP):
cls = TypeDef
definition = definition[len(FUNCTIONS_SEP) :].strip()
definition = definition[len(TYPES_SEP) :].strip()
else:
raise ValueError("bad separator")