Port tl-gen from grammers

This commit is contained in:
Lonami Exo
2023-07-03 19:19:20 +02:00
parent fc6984d423
commit fed06f40ed
39 changed files with 2207 additions and 40 deletions

View File

@@ -0,0 +1,4 @@
from . import codegen, tl_parser
from .version import __version__
__all__ = ["codegen", "tl_parser"]

View File

@@ -0,0 +1,5 @@
from .fakefs import FakeFs, SourceWriter
from .generator import generate
from .loader import ParsedTl, load_tl_file
__all__ = ["FakeFs", "SourceWriter", "generate", "ParsedTl", "load_tl_file"]

View File

@@ -0,0 +1,44 @@
import weakref
from pathlib import Path
from typing import Dict
class FakeFs:
def __init__(self) -> None:
self._files: Dict[Path, bytearray] = {}
def open(self, path: Path) -> "SourceWriter":
return SourceWriter(self, path)
def write(self, path: Path, line: str) -> None:
file = self._files.get(path)
if file is None:
self._files[path] = file = bytearray()
file += line.encode("utf-8")
def materialize(self, root: Path) -> None:
for stem, data in self._files.items():
path = root / stem
path.parent.mkdir(exist_ok=True)
with path.open("wb") as fd:
fd.write(data)
def __contains__(self, path: Path) -> bool:
return path in self._files
class SourceWriter:
def __init__(self, fs: FakeFs, path: Path) -> None:
self._fs = weakref.ref(fs)
self._path = path
self._indent = ""
def write(self, string: str) -> None:
if fs := self._fs():
fs.write(self._path, f"{self._indent}{string}\n")
def indent(self, n: int = 1) -> None:
self._indent += " " * n
def dedent(self, n: int = 1) -> None:
self._indent = self._indent[: -2 * n]

View File

@@ -0,0 +1,178 @@
from pathlib import Path
from typing import Set
from .fakefs import FakeFs, SourceWriter
from .loader import ParsedTl
from .serde.common import (
inner_type_fmt,
is_computed,
param_type_fmt,
to_class_name,
to_method_name,
)
from .serde.deserialization import generate_read
from .serde.serialization import generate_function, generate_write
def generate_init(writer: SourceWriter, namespaces: Set[str]) -> None:
sorted_ns = list(namespaces)
sorted_ns.sort()
if sorted_ns:
sorted_import = ", ".join(sorted_ns)
writer.write(f"from ._nons import *")
writer.write(f"from . import {sorted_import}")
sorted_all = ", ".join(f"{ns!r}" for ns in sorted_ns)
writer.write(f"__all__ = [{sorted_all}]")
def generate(fs: FakeFs, tl: ParsedTl) -> None:
generated_types = {
"True",
"Bool",
} # initial set is considered to be "compiler built-ins"
ignored_types = {"true", "boolTrue", "boolFalse"} # also "compiler built-ins"
abc_namespaces = set()
type_namespaces = set()
function_namespaces = set()
generated_type_names = []
for typedef in tl.typedefs:
if typedef.ty.full_name not in generated_types:
if len(typedef.ty.namespace) >= 2:
raise NotImplementedError("nested abc-namespaces are not supported")
elif len(typedef.ty.namespace) == 1:
abc_namespaces.add(typedef.ty.namespace[0])
abc_path = (Path("abcs") / typedef.ty.namespace[0]).with_suffix(".py")
else:
abc_path = Path("abcs/_nons.py")
if abc_path not in fs:
fs.write(abc_path, "from abc import ABCMeta\n")
fs.write(abc_path, "from ..core.serializable import Serializable\n")
fs.write(
abc_path,
f"class {to_class_name(typedef.ty.name)}(Serializable, metaclass=ABCMeta): pass\n",
)
generated_types.add(typedef.ty.full_name)
if typedef.name in ignored_types:
continue
property_params = [p for p in typedef.params if not is_computed(p.ty)]
if len(typedef.namespace) >= 2:
raise NotImplementedError("nested type-namespaces are not supported")
elif len(typedef.namespace) == 1:
type_namespaces.add(typedef.namespace[0])
type_path = (Path("types") / typedef.namespace[0]).with_suffix(".py")
else:
type_path = Path("types/_nons.py")
writer = fs.open(type_path)
if type_path not in fs:
writer.write(f"import struct")
writer.write(f"from typing import List, Optional, Self")
writer.write(f"from .. import abcs")
writer.write(f"from ..core.reader import Reader")
writer.write(f"from ..core.serializable import serialize_bytes_to")
ns = f"{typedef.namespace[0]}." if typedef.namespace else ""
generated_type_names.append(f"{ns}{to_class_name(typedef.name)}")
# class Type(BaseType)
writer.write(
f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):"
)
# __slots__ = ('params', ...)
slots = " ".join(f"'{p.name}'," for p in property_params)
writer.write(f" __slots__ = ({slots})")
# def constructor_id()
writer.write(f" @classmethod")
writer.write(f" def constructor_id(_) -> int:")
writer.write(f" return {hex(typedef.id)}")
# def __init__()
if property_params:
params = "".join(
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
)
writer.write(f" def __init__(_s{params}) -> None:")
for p in property_params:
writer.write(f" _s.{p.name} = {p.name}")
# def _read_from()
writer.write(f" @classmethod")
writer.write(f" def _read_from(cls, reader: Reader) -> Self:")
writer.indent(2)
generate_read(writer, typedef)
params = ", ".join(f"{p.name}=_{p.name}" for p in property_params)
writer.write(f"return cls({params})")
writer.dedent(2)
# def _write_to()
writer.write(f" def _write_to(self, buffer: bytearray) -> None:")
if typedef.params:
writer.indent(2)
generate_write(writer, typedef)
writer.dedent(2)
else:
writer.write(f" pass")
for functiondef in tl.functiondefs:
required_params = [p for p in functiondef.params if not is_computed(p.ty)]
if len(functiondef.namespace) >= 2:
raise NotImplementedError("nested function-namespaces are not supported")
elif len(functiondef.namespace) == 1:
function_namespaces.add(functiondef.namespace[0])
function_path = (Path("functions") / functiondef.namespace[0]).with_suffix(
".py"
)
else:
function_path = Path("functions/_nons.py")
writer = fs.open(function_path)
if function_path not in fs:
writer.write(f"import struct")
writer.write(f"from typing import List, Optional, Self")
writer.write(f"from .. import abcs")
writer.write(f"from ..core.request import Request")
writer.write(f"from ..core.serializable import serialize_bytes_to")
# def name(params, ...)
params = ", ".join(f"{p.name}: {param_type_fmt(p.ty)}" for p in required_params)
writer.write(f"def {to_method_name(functiondef.name)}({params}) -> Request:")
writer.indent(2)
generate_function(writer, functiondef)
writer.dedent(2)
generate_init(fs.open(Path("abcs/__init__.py")), abc_namespaces)
generate_init(fs.open(Path("types/__init__.py")), type_namespaces)
generate_init(fs.open(Path("functions/__init__.py")), function_namespaces)
generated_type_names.sort()
writer = fs.open(Path("layer.py"))
writer.write(f"from . import types")
writer.write(f"from .core import Serializable, Reader")
writer.write(f"from typing import cast, Tuple, Type")
writer.write(f"LAYER = {tl.layer!r}")
writer.write(
"TYPE_MAPPING = {t.constructor_id(): t for t in cast(Tuple[Type[Serializable]], ("
)
for name in generated_type_names:
writer.write(f" types.{name},")
writer.write("))}")
writer.write(
"Reader._get_ty = TYPE_MAPPING.get # type: ignore [method-assign, assignment]"
)
writer.write(f"__all__ = ['LAYER', 'TYPE_MAPPING']")

View File

@@ -0,0 +1,39 @@
import re
from dataclasses import dataclass
from typing import List, Optional
from ...tl_parser import Definition, FunctionDef, TypeDef, parse_tl_file
@dataclass
class ParsedTl:
layer: Optional[int]
typedefs: List[Definition]
functiondefs: List[Definition]
def load_tl_file(path: str) -> ParsedTl:
typedefs, functiondefs = [], []
with open(path, "r", encoding="utf-8") as fd:
contents = fd.read()
if m := re.search(r"//\s*LAYER\s+(\d+)", contents):
layer = int(m[1])
else:
layer = None
for definition in parse_tl_file(contents):
if isinstance(definition, Exception):
# generic types (such as vector) is known to not be implemented
if definition.args[0] != "not implemented":
raise
elif isinstance(definition, TypeDef):
typedefs.append(definition)
elif isinstance(definition, FunctionDef):
functiondefs.append(definition)
else:
raise TypeError(f"unexpected type: {type(definition)}")
return ParsedTl(
layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)
)

View File

@@ -0,0 +1,102 @@
import re
from typing import Iterator
from ....tl_parser import BaseParameter, FlagsParameter, NormalParameter, Type
def to_class_name(name: str) -> str:
return re.sub(r"(?:^|_)([a-z])", lambda m: m[1].upper(), name)
def to_method_name(name: str) -> str:
snake_case = re.sub(
r"_+[A-Za-z]+|[A-Z]*[a-z]+", lambda m: "_" + m[0].replace("_", "").lower(), name
)
return snake_case.strip("_")
def gen_tmp_names() -> Iterator[str]:
i = 0
while True:
yield f"_t{i}"
i += 1
def is_computed(ty: BaseParameter) -> bool:
return isinstance(ty, FlagsParameter)
def is_trivial(ty: BaseParameter) -> bool:
return (
isinstance(ty, FlagsParameter)
or isinstance(ty, NormalParameter)
and not ty.flag
and ty.ty.name in ("int", "long", "double", "Bool")
)
_TRIVIAL_STRUCT_MAP = {"int": "i", "long": "q", "double": "d", "Bool": "I"}
def trivial_struct_fmt(ty: BaseParameter) -> str:
try:
return (
_TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I"
)
except KeyError:
raise ValueError("input param was not trivial")
_INNER_TYPE_MAP = {
"Bool": "bool",
"true": "bool",
"int": "int",
"long": "int",
"int128": "int",
"int256": "int",
"double": "float",
"bytes": "bytes",
"string": "str",
}
def inner_type_fmt(ty: Type) -> str:
builtin_ty = _INNER_TYPE_MAP.get(ty.name)
if builtin_ty:
return builtin_ty
elif ty.bare:
return to_class_name(ty.name)
elif ty.generic_ref:
return "bytes"
else:
ns = (".".join(ty.namespace) + ".") if ty.namespace else ""
return f"abcs.{ns}{to_class_name(ty.name)}"
def param_type_fmt(ty: BaseParameter) -> str:
if isinstance(ty, FlagsParameter):
return "int"
elif not isinstance(ty, NormalParameter):
raise TypeError("unexpected input type {ty}")
inner_ty: Type
if ty.ty.generic_arg:
if ty.ty.name not in ("Vector", "vector"):
raise NotImplementedError(
"generic_arg type for non-vectors not implemented"
)
inner_ty = ty.ty.generic_arg
else:
inner_ty = ty.ty
res = inner_type_fmt(inner_ty)
if ty.ty.generic_arg:
res = f"List[{res}]"
if ty.flag and ty.ty.name != "true":
res = f"Optional[{res}]"
return res

View File

@@ -0,0 +1,100 @@
import struct
from itertools import groupby
from typing import Optional, Tuple
from ....tl_parser import Definition, NormalParameter, Parameter, Type
from ..fakefs import SourceWriter
from .common import inner_type_fmt, is_trivial, to_class_name, trivial_struct_fmt
def reader_read_fmt(ty: Type) -> Tuple[str, Optional[str]]:
if is_trivial(NormalParameter(ty=ty, flag=None)):
fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None))
size = struct.calcsize(f"<{fmt}")
return f"reader.read_fmt(f'<{fmt}', {size})[0]", None
elif ty.name == "string":
return f"str(reader.read_bytes(), 'utf-8', 'replace')", None
elif ty.name == "bytes":
return f"reader.read_bytes()", None
elif ty.name == "int128":
return f"int.from_bytes(reader.read(16), 'little', signed=True)", None
elif ty.name == "int256":
return f"int.from_bytes(reader.read(32), 'little', signed=True)", None
elif ty.bare:
return f"{to_class_name(ty.name)}._read_from(reader)", None
else:
return f"reader.read_serializable({inner_type_fmt(ty)})", "type-abstract"
def generate_normal_param_read(
writer: SourceWriter, name: str, param: NormalParameter
) -> None:
flag_check = f"_{param.flag.name} & {1 << param.flag.index}" if param.flag else None
if param.ty.name == "true":
if not flag_check:
raise NotImplementedError("true parameter is expected to be a flag")
writer.write(f"_{name} = ({flag_check}) != 0")
elif param.ty.generic_ref:
raise NotImplementedError("generic_ref deserialization not implemented")
else:
if flag_check:
writer.write(f"if {flag_check}:")
writer.indent()
if param.ty.generic_arg:
if param.ty.name not in ("Vector", "vector"):
raise NotImplementedError(
"generic_arg deserialization for non-vectors not implemented"
)
if param.ty.bare:
writer.write(f"__len = reader.read_fmt('<i', 4)[0]")
writer.write(f"assert __len >= 0")
else:
writer.write(f"__vid, __len = reader.read_fmt('<ii', 8)")
writer.write(f"assert __vid == 0x1cb5c415 and __len >= 0")
generic = NormalParameter(ty=param.ty.generic_arg, flag=None)
if is_trivial(generic):
fmt = trivial_struct_fmt(generic)
size = struct.calcsize(f"<{fmt}")
writer.write(
f"_{name} = reader.read_fmt(f'<{{__len}}{fmt}', __len * {size})[0]"
)
if param.ty.generic_arg.name == "Bool":
writer.write(
f"assert all(__x in (0xbc799737, 0x0x997275b5) for __x in _{name})"
)
writer.write(f"_{name} = [_{name} == 0x997275b5]")
else:
fmt_read, type_ignore = reader_read_fmt(param.ty.generic_arg)
comment = f" # type: ignore [{type_ignore}]" if type_ignore else ""
writer.write(f"_{name} = [{fmt_read} for _ in range(__len)]{comment}")
else:
fmt_read, type_ignore = reader_read_fmt(param.ty)
comment = f" # type: ignore [{type_ignore}]" if type_ignore else ""
writer.write(f"_{name} = {fmt_read}{comment}")
if flag_check:
writer.dedent()
writer.write(f"else:")
writer.write(f" _{name} = None")
def generate_read(writer: SourceWriter, defn: Definition) -> None:
for trivial, iter in groupby(
defn.params,
key=lambda p: is_trivial(p.ty),
):
if trivial:
# As an optimization, struct.unpack can handle more than one element at a time.
group = list(iter)
names = "".join(f"_{param.name}, " for param in group)
fmt = "".join(trivial_struct_fmt(param.ty) for param in group)
size = struct.calcsize(f"<{fmt}")
writer.write(f"{names}= reader.read_fmt('<{fmt}', {size})")
else:
for param in iter:
if not isinstance(param.ty, NormalParameter):
raise RuntimeError("FlagsParameter should be considered trivial")
generate_normal_param_read(writer, param.name, param.ty)

View File

@@ -0,0 +1,161 @@
from itertools import groupby
from typing import Iterator
from ....tl_parser import Definition, FlagsParameter, NormalParameter, Parameter, Type
from ..fakefs import SourceWriter
from .common import gen_tmp_names, is_computed, is_trivial, trivial_struct_fmt
def param_value_expr(param: Parameter) -> str:
is_bool = isinstance(param.ty, NormalParameter) and param.ty.ty.name == "Bool"
pre = "0x997275b5 if " if is_bool else ""
mid = f"_{param.name}" if is_computed(param.ty) else f"self.{param.name}"
suf = " else 0xbc799737" if is_bool else ""
return f"{pre}{mid}{suf}"
def generate_buffer_append(
writer: SourceWriter, buffer: str, name: str, ty: Type
) -> None:
if is_trivial(NormalParameter(ty=ty, flag=None)):
fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None))
if ty.name == "Bool":
writer.write(
f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))"
)
else:
writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})")
elif ty.generic_ref:
writer.write(f"{buffer} += {name}") # assume previously-serialized
elif ty.name == "string":
writer.write(f"serialize_bytes_to({buffer}, {name}.encode('utf-8'))")
elif ty.name == "bytes":
writer.write(f"serialize_bytes_to({buffer}, {name})")
elif ty.name == "int128":
writer.write(f"{buffer} += {name}.to_bytes(16, 'little', signed=True)")
elif ty.name == "int256":
writer.write(f"{buffer} += {name}.to_bytes(32, 'little', signed=True)")
elif ty.bare:
writer.write(f"{name}._write_to({buffer})")
else:
writer.write(f"{name}._write_boxed_to({buffer})")
def generate_normal_param_write(
writer: SourceWriter,
tmp_names: Iterator[str],
buffer: str,
name: str,
param: NormalParameter,
) -> None:
if param.ty.name == "true":
return # special-cased "built-in"
if param.flag:
writer.write(f"if {name} is not None:")
writer.indent()
if param.ty.generic_arg:
if param.ty.name not in ("Vector", "vector"):
raise NotImplementedError(
"generic_arg deserialization for non-vectors not implemented"
)
if param.ty.bare:
writer.write(f"{buffer} += struct.pack('<i', len({name}))")
else:
writer.write(f"{buffer} += struct.pack('<ii', 0x1cb5c415, len({name}))")
generic = NormalParameter(ty=param.ty.generic_arg, flag=None)
if is_trivial(generic):
fmt = trivial_struct_fmt(generic)
if param.ty.generic_arg.name == "Bool":
tmp = next(tmp_names)
writer.write(
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *(0x997275b5 if {tmp} else 0xbc799737 for {tmp} in {name}))"
)
else:
writer.write(
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *{name})"
)
else:
tmp = next(tmp_names)
writer.write(f"for {tmp} in {name}:")
writer.indent()
generate_buffer_append(writer, buffer, tmp, param.ty.generic_arg)
writer.dedent()
else:
generate_buffer_append(writer, buffer, f"{name}", param.ty)
if param.flag:
writer.dedent()
def generate_write(writer: SourceWriter, defn: Definition) -> None:
tmp_names = gen_tmp_names()
for trivial, iter in groupby(
defn.params,
key=lambda p: is_trivial(p.ty),
):
if trivial:
# As an optimization, struct.pack can handle more than one element at a time.
group = list(iter)
for param in group:
if isinstance(param.ty, FlagsParameter):
flags = " | ".join(
f"({1 << p.ty.flag.index} if self.{p.name} else 0)"
if p.ty.ty.name == "true"
else f"(0 if self.{p.name} is None else {1 << p.ty.flag.index})"
for p in defn.params
if isinstance(p.ty, NormalParameter)
and p.ty.flag
and p.ty.flag.name == param.name
)
writer.write(f"_{param.name} = {flags or 0}")
names = ", ".join(map(param_value_expr, group))
fmt = "".join(trivial_struct_fmt(param.ty) for param in group)
writer.write(f"buffer += struct.pack('<{fmt}', {names})")
else:
for param in iter:
if not isinstance(param.ty, NormalParameter):
raise RuntimeError("FlagsParameter should be considered trivial")
generate_normal_param_write(
writer, tmp_names, "buffer", f"self.{param.name}", param.ty
)
def generate_function(writer: SourceWriter, defn: Definition) -> None:
tmp_names = gen_tmp_names()
writer.write("_buffer = bytearray()")
for trivial, iter in groupby(
defn.params,
key=lambda p: is_trivial(p.ty),
):
if trivial:
# As an optimization, struct.pack can handle more than one element at a time.
group = list(iter)
for param in group:
if isinstance(param.ty, FlagsParameter):
flags = " | ".join(
f"({1 << p.ty.flag.index} if {p.name} else 0)"
if p.ty.ty.name == "true"
else f"(0 if {p.name} is None else {1 << p.ty.flag.index})"
for p in defn.params
if isinstance(p.ty, NormalParameter)
and p.ty.flag
and p.ty.flag.name == param.name
)
writer.write(f"{param.name} = {flags or 0}")
names = ", ".join(p.name for p in group)
fmt = "".join(trivial_struct_fmt(param.ty) for param in group)
writer.write(f"_buffer += struct.pack('<{fmt}', {names})")
else:
for param in iter:
if not isinstance(param.ty, NormalParameter):
raise RuntimeError("FlagsParameter should be considered trivial")
generate_normal_param_write(
writer, tmp_names, "_buffer", param.name, param.ty
)
writer.write("return Request(b'' + _buffer)")

View File

@@ -0,0 +1,118 @@
from dataclasses import dataclass
from typing import List, Self, Set
from ..utils import infer_id
from .parameter import Parameter, TypeDefNotImplemented
from .parameter_type import FlagsParameter, NormalParameter
from .ty import Type
@dataclass
class Definition:
namespace: List[str]
name: str
id: int
params: List[Parameter]
ty: Type
@classmethod
def from_str(cls, definition: str) -> Self:
if not definition or definition.isspace():
raise ValueError("empty")
parts = definition.split("=")
if len(parts) < 2:
raise ValueError("missing type")
left, ty_str, *_ = map(str.strip, parts)
try:
ty = Type.from_str(ty_str)
except ValueError as e:
if e.args[0] == "empty":
raise ValueError("missing type")
else:
raise
if (pos := left.find(" ")) != -1:
name, middle = left[:pos], left[pos:].strip()
else:
name, middle = left.strip(), ""
parts = name.split("#")
if len(parts) < 2:
name, id_str = parts[0], None
else:
name, id_str, *_ = parts
namespace = name.split(".")
if not all(namespace):
raise ValueError("missing name")
name = namespace.pop()
if id_str is None:
id = infer_id(definition)
else:
try:
id = int(id_str, 16)
except ValueError:
raise ValueError("invalid id")
type_defs: List[str] = []
flag_defs = []
params = []
for param_str in middle.split():
try:
param = Parameter.from_str(param_str)
except TypeDefNotImplemented as e:
type_defs.append(e.name)
continue
if isinstance(param.ty, FlagsParameter):
flag_defs.append(param.name)
elif not isinstance(param.ty, NormalParameter):
raise NotImplementedError
elif param.ty.ty.generic_ref and param.ty.ty.name not in type_defs:
raise ValueError("missing def")
elif param.ty.flag and param.ty.flag.name not in flag_defs:
raise ValueError("missing def")
params.append(param)
if ty.name in type_defs:
ty.generic_ref = True
return cls(
namespace=namespace,
name=name,
id=id,
params=params,
ty=ty,
)
@property
def full_name(self) -> str:
ns = ".".join(self.namespace) + "." if self.namespace else ""
return f"{ns}{self.name}"
def __str__(self) -> str:
res = ""
for ns in self.namespace:
res += f"{ns}."
res += f"{self.name}#{self.id:x}"
def_set: Set[str] = set()
for param in self.params:
if isinstance(param.ty, NormalParameter):
def_set.update(param.ty.ty.find_generic_refs())
type_defs = list(sorted(def_set))
for type_def in type_defs:
res += f" {{{type_def}:Type}}"
for param in self.params:
res += f" {param}"
res += f" = {self.ty}"
return res

View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
from typing import Self
@dataclass
class Flag:
name: str
index: int
@classmethod
def from_str(cls, ty: str) -> Self:
if (dot_pos := ty.find(".")) != -1:
try:
index = int(ty[dot_pos + 1 :])
except ValueError:
raise ValueError("invalid flag")
else:
return cls(name=ty[:dot_pos], index=index)
else:
raise ValueError("invalid flag")
def __str__(self) -> str:
return f"{self.name}.{self.index}"

View File

@@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Self
from .parameter_type import BaseParameter
class TypeDefNotImplemented(NotImplementedError):
def __init__(self, name: str):
super().__init__(f"typedef not implemented: {name}")
self.name = name
@dataclass
class Parameter:
name: str
ty: BaseParameter
@classmethod
def from_str(cls, param: str) -> Self:
if param.startswith("{"):
if param.endswith(":Type}"):
raise TypeDefNotImplemented(param[1 : param.index(":")])
else:
raise ValueError("missing def")
parts = param.split(":")
if not parts:
raise ValueError("empty")
elif len(parts) == 1:
raise ValueError("not implemented")
else:
name, ty, *_ = parts
if not name:
raise ValueError("empty")
return cls(name=name, ty=BaseParameter.from_str(ty))
def __str__(self) -> str:
return f"{self.name}:{self.ty}"

View File

@@ -0,0 +1,39 @@
from abc import ABC
from dataclasses import dataclass
from typing import Optional, Union
from .flag import Flag
from .ty import Type
class BaseParameter(ABC):
@staticmethod
def from_str(ty: str) -> Union["FlagsParameter", "NormalParameter"]:
if not ty:
raise ValueError("empty")
if ty == "#":
return FlagsParameter()
if (pos := ty.find("?")) != -1:
ty, flag = ty[pos + 1 :], Flag.from_str(ty[:pos])
else:
flag = None
return NormalParameter(ty=Type.from_str(ty), flag=flag)
@dataclass
class FlagsParameter(BaseParameter):
def __str__(self) -> str:
return "#"
@dataclass
class NormalParameter(BaseParameter):
ty: Type
flag: Optional[Flag]
def __str__(self) -> str:
res = ""
if self.flag is not None:
res += f"{self.flag}?"
res += str(self.ty)
return res

View File

@@ -0,0 +1,60 @@
from dataclasses import dataclass
from typing import Iterator, List, Optional, Self
@dataclass
class Type:
namespace: List[str]
name: str
bare: bool
generic_ref: bool
generic_arg: Optional[Self]
@classmethod
def from_str(cls, ty: str) -> Self:
stripped = ty.lstrip("!")
ty, generic_ref = stripped, stripped != ty
if (pos := ty.find("<")) != -1:
if not ty.endswith(">"):
raise ValueError("invalid generic")
ty, generic_arg = ty[:pos], Type.from_str(ty[pos + 1 : -1])
else:
generic_arg = None
namespace = ty.split(".")
if not all(namespace):
raise ValueError("empty")
name = namespace.pop()
bare = name[0].islower()
return cls(
namespace=namespace,
name=name,
bare=bare,
generic_ref=generic_ref,
generic_arg=generic_arg,
)
@property
def full_name(self) -> str:
ns = ".".join(self.namespace) + "." if self.namespace else ""
return f"{ns}{self.name}"
def __str__(self) -> str:
res = ""
for ns in self.namespace:
res += f"{ns}."
if self.generic_ref:
res += "!"
res += self.name
if self.generic_arg is not None:
res += f"<{self.generic_arg}>"
return res
def find_generic_refs(self) -> Iterator[str]:
if self.generic_ref:
yield self.name
if self.generic_arg is not None:
yield from self.generic_arg.find_generic_refs()

View File

@@ -0,0 +1,47 @@
from typing import Iterator, Type
from .tl.definition import Definition
from .utils import remove_tl_comments
DEFINITION_SEP = ";"
CATEGORY_MARKER = "---"
FUNCTIONS_SEP = f"{CATEGORY_MARKER}functions---"
TYPES_SEP = f"{CATEGORY_MARKER}types---"
class TypeDef(Definition):
pass
class FunctionDef(Definition):
pass
def iterate(contents: str) -> Iterator[TypeDef | FunctionDef | Exception]:
contents = remove_tl_comments(contents)
index = 0
cls: Type[TypeDef] | Type[FunctionDef] = TypeDef
while index < len(contents):
if (end := contents.find(DEFINITION_SEP, index)) == -1:
end = len(contents)
definition = contents[index:end].strip()
index = end + len(DEFINITION_SEP)
if not definition:
continue
if definition.startswith(CATEGORY_MARKER):
if definition.startswith(FUNCTIONS_SEP):
cls = FunctionDef
definition = definition[len(FUNCTIONS_SEP) :].strip()
elif definition.startswith(TYPES_SEP):
cls = TypeDef
definition = definition[len(FUNCTIONS_SEP) :].strip()
else:
raise ValueError("bad separator")
try:
yield cls.from_str(definition)
except Exception as e:
yield e

View File

@@ -0,0 +1,20 @@
import re
import zlib
def remove_tl_comments(contents: str) -> str:
return re.sub(r"//[^\n]*(?=\n)", "", contents)
def infer_id(definition: str) -> int:
representation = (
definition.replace(":bytes ", ": string")
.replace("?bytes ", "? string")
.replace("<", " ")
.replace(">", "")
.replace("{", "")
.replace("}", "")
)
representation = re.sub(r" \w+:flags\.\d+\?true", "", representation)
return zlib.crc32(representation.encode("ascii"))

View File

@@ -0,0 +1,3 @@
from .._impl.codegen import FakeFs, ParsedTl, generate
__all__ = ["FakeFs", "ParsedTl", "generate"]

View File

@@ -0,0 +1,31 @@
import sys
from pathlib import Path
from .._impl.codegen import FakeFs, generate, load_tl_file
HELP = f"""
USAGE:
python -m {__package__} <TL_FILE> <OUT_DIR>
ARGS:
<TL_FILE>
The path to the `.tl' file to generate Python code from.
<OUT_DIR>
The directory where the generated code will be written to.
""".strip()
def main() -> None:
if len(sys.argv) != 3:
print(HELP)
sys.exit(1)
tl, out = sys.argv[1:]
fs = FakeFs()
generate(fs, load_tl_file(tl))
fs.materialize(Path(out))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,25 @@
from .._impl.tl_parser.tl.definition import Definition
from .._impl.tl_parser.tl.flag import Flag
from .._impl.tl_parser.tl.parameter import Parameter, TypeDefNotImplemented
from .._impl.tl_parser.tl.parameter_type import (
BaseParameter,
FlagsParameter,
NormalParameter,
)
from .._impl.tl_parser.tl.ty import Type
from .._impl.tl_parser.tl_iterator import FunctionDef, TypeDef
from .._impl.tl_parser.tl_iterator import iterate as parse_tl_file
__all__ = [
"Definition",
"Flag",
"Parameter",
"TypeDefNotImplemented",
"BaseParameter",
"FlagsParameter",
"NormalParameter",
"Type",
"FunctionDef",
"TypeDef",
"parse_tl_file",
]

View File

@@ -0,0 +1 @@
__version__ = "0.1.0"