Port mtproto from grammers

This commit is contained in:
Lonami Exo
2023-07-09 21:16:55 +02:00
parent 9636ef35c1
commit 269ee4f05f
35 changed files with 1747 additions and 57 deletions

View File

@@ -90,13 +90,15 @@ def decrypt_data_v2(ciphertext: bytes, auth_key: AuthKey) -> bytes:
return plaintext
def generate_key_data_from_nonce(server_nonce: bytes, new_nonce: bytes) -> CalcKey:
hash1 = sha1(new_nonce + server_nonce).digest()
hash2 = sha1(server_nonce + new_nonce).digest()
hash3 = sha1(new_nonce + new_nonce).digest()
def generate_key_data_from_nonce(server_nonce: int, new_nonce: int) -> CalcKey:
server_bytes = server_nonce.to_bytes(16)
new_bytes = new_nonce.to_bytes(32)
hash1 = sha1(new_bytes + server_bytes).digest()
hash2 = sha1(server_bytes + new_bytes).digest()
hash3 = sha1(new_bytes + new_bytes).digest()
key = hash1 + hash2[:12]
iv = hash2[12:20] + hash3 + new_nonce[:4]
iv = hash2[12:20] + hash3 + new_bytes[:4]
return CalcKey(key, iv)
@@ -108,3 +110,6 @@ def encrypt_ige(plaintext: bytes, key: bytes, iv: bytes) -> bytes:
def decrypt_ige(padded_ciphertext: bytes, key: bytes, iv: bytes) -> bytes:
return ige_decrypt(padded_ciphertext, key, iv)
__all__ = ["AuthKey", "encrypt_data_v2", "decrypt_data_v2"]

View File

@@ -19,5 +19,9 @@ class AuthKey:
def __bytes__(self) -> bytes:
return self.data
def calc_new_nonce_hash(self, new_nonce: bytes, number: int) -> bytes:
return sha1(new_nonce + bytes((number,)) + self.aux_hash).digest()[4:]
def calc_new_nonce_hash(self, new_nonce: int, number: int) -> int:
return int.from_bytes(
sha1(new_nonce.to_bytes(32) + number.to_bytes(1) + self.aux_hash).digest()[
4:
]
)

View File

@@ -1,3 +1,4 @@
import os
import struct
from hashlib import sha1
@@ -8,15 +9,21 @@ from ..tl.core import serialize_bytes_to
def compute_fingerprint(key: PublicKey) -> int:
buffer = bytearray()
serialize_bytes_to(buffer, int.to_bytes(key.n, (key.n.bit_length() + 7) // 8))
serialize_bytes_to(buffer, int.to_bytes(key.e, (key.e.bit_length() + 7) // 8))
serialize_bytes_to(buffer, key.n.to_bytes((key.n.bit_length() + 7) // 8))
serialize_bytes_to(buffer, key.e.to_bytes((key.e.bit_length() + 7) // 8))
fingerprint = struct.unpack("<q", sha1(buffer).digest()[-8:])[0]
assert isinstance(fingerprint, int)
return fingerprint
def encrypt_hashed(data: bytes, key: PublicKey) -> bytes:
return encrypt(sha1(data).digest() + data, key)
def encrypt_hashed(data: bytes, key: PublicKey, random_data: bytes) -> bytes:
# Cannot use `rsa.encrypt` because it's not deterministic and requires its own padding.
padding_length = 235 - len(data)
assert padding_length >= 0 and len(random_data) >= padding_length
to_encrypt = sha1(data).digest() + data + random_data[:padding_length]
payload = int.from_bytes(to_encrypt)
encrypted = pow(payload, key.e, key.n)
return encrypted.to_bytes(256)
# From my.telegram.org.

View File

@@ -0,0 +1,307 @@
import os
import struct
import time
from dataclasses import dataclass
from hashlib import sha1
from typing import Tuple
from telethon._impl.crypto import decrypt_ige, encrypt_ige, generate_key_data_from_nonce
from telethon._impl.crypto.auth_key import AuthKey
from telethon._impl.crypto.factorize import factorize
from telethon._impl.crypto.rsa import RSA_KEYS, encrypt_hashed
from telethon._impl.tl.core.reader import Reader
from ..tl.mtproto.abcs import ServerDhInnerData as AbcServerDhInnerData
from ..tl.mtproto.abcs import ServerDhParams, SetClientDhParamsAnswer
from ..tl.mtproto.functions import req_dh_params, req_pq_multi, set_client_dh_params
from ..tl.mtproto.types import (
ClientDhInnerData,
DhGenFail,
DhGenOk,
DhGenRetry,
PQInnerData,
ResPq,
ServerDhInnerData,
ServerDhParamsFail,
ServerDhParamsOk,
)
@dataclass
class Step1:
nonce: int
@dataclass
class Step2:
nonce: int
server_nonce: int
new_nonce: int
@dataclass
class Step3:
nonce: int
server_nonce: int
new_nonce: int
gab: int
time_offset: int
@dataclass
class CreatedKey:
auth_key: AuthKey
time_offset: int
first_salt: int
@dataclass
class DhGenData:
nonce: int
server_nonce: int
new_nonce_hash: int
nonce_number: int
def _do_step1(random_bytes: bytes) -> Tuple[bytes, Step1]:
assert len(random_bytes) == 16
nonce = int.from_bytes(random_bytes)
return req_pq_multi(nonce=nonce), Step1(nonce=nonce)
def step1() -> Tuple[bytes, Step1]:
return _do_step1(os.urandom(16))
def _do_step2(data: Step1, response: bytes, random_bytes: bytes) -> Tuple[bytes, Step2]:
assert len(random_bytes) == 288
nonce = data.nonce
res_pq = ResPq.from_bytes(response)
if len(res_pq.pq) != 8:
raise ValueError(f"invalid pq size: {len(res_pq.pq)}")
pq = struct.unpack(">Q", res_pq.pq)[0]
p, q = factorize(pq)
new_nonce = int.from_bytes(random_bytes[:32])
random_bytes = random_bytes[32:]
# https://core.telegram.org/mtproto/auth_key#dh-exchange-initiation
p_bytes = p.to_bytes((p.bit_length() + 7) // 8)
q_bytes = q.to_bytes((q.bit_length() + 7) // 8)
pq_inner_data = bytes(
PQInnerData(
pq=res_pq.pq,
p=p_bytes,
q=q_bytes,
nonce=nonce,
server_nonce=res_pq.server_nonce,
new_nonce=new_nonce,
)
)
try:
fingerprint = next(
fp for fp in res_pq.server_public_key_fingerprints if fp in RSA_KEYS
)
except StopIteration:
raise ValueError(
f"unknown fingerprints: {res_pq.server_public_key_fingerprints}"
)
key = RSA_KEYS[fingerprint]
ciphertext = encrypt_hashed(pq_inner_data, key, random_bytes)
return req_dh_params(
nonce=nonce,
server_nonce=res_pq.server_nonce,
p=p_bytes,
q=q_bytes,
public_key_fingerprint=fingerprint,
encrypted_data=ciphertext,
), Step2(nonce=nonce, server_nonce=res_pq.server_nonce, new_nonce=new_nonce)
def step2(data: Step1, response: bytes) -> Tuple[bytes, Step2]:
return _do_step2(data, response, os.urandom(288))
def _do_step3(
data: Step2, response: bytes, random_bytes: bytes, now: int
) -> Tuple[bytes, Step3]:
assert len(random_bytes) == 272
nonce = data.nonce
server_nonce = data.server_nonce
new_nonce = data.new_nonce
server_dh_params = ServerDhParams.from_bytes(response)
if isinstance(server_dh_params, ServerDhParamsFail):
check_nonce(server_dh_params.nonce, nonce)
check_server_nonce(server_dh_params.server_nonce, server_nonce)
new_nonce_hash = int.from_bytes(sha1(new_nonce.to_bytes(16)).digest()[4:])
check_new_nonce_hash(server_dh_params.new_nonce_hash, new_nonce_hash)
raise ValueError("server failed to provide dh params")
else:
assert isinstance(server_dh_params, ServerDhParamsOk)
check_nonce(server_dh_params.nonce, nonce)
check_server_nonce(server_dh_params.server_nonce, server_nonce)
if len(server_dh_params.encrypted_answer) % 16 != 0:
raise ValueError(
f"encrypted response not padded with size: {len(server_dh_params.encrypted_answer)}"
)
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
plain_text_answer = decrypt_ige(server_dh_params.encrypted_answer, key, iv)
got_answer_hash = plain_text_answer[:20]
plain_text_reader = Reader(plain_text_answer[20:])
server_dh_inner = AbcServerDhInnerData._read_from(plain_text_reader)
assert isinstance(server_dh_inner, ServerDhInnerData)
expected_answer_hash = sha1(
plain_text_answer[20 : 20 + plain_text_reader._pos]
).digest()
if got_answer_hash != expected_answer_hash:
raise ValueError("invalid answer hash")
check_nonce(server_dh_inner.nonce, nonce)
check_server_nonce(server_dh_inner.server_nonce, server_nonce)
dh_prime = int.from_bytes(server_dh_inner.dh_prime)
g = server_dh_inner.g
g_a = int.from_bytes(server_dh_inner.g_a)
time_offset = server_dh_inner.server_time - now
b = int.from_bytes(random_bytes[:256])
g_b = pow(g, b, dh_prime)
gab = pow(g_a, b, dh_prime)
random_bytes = random_bytes[256:]
# https://core.telegram.org/mtproto/auth_key#dh-key-exchange-complete
check_g_in_range(g, 1, dh_prime - 1)
check_g_in_range(g_a, 1, dh_prime - 1)
check_g_in_range(g_b, 1, dh_prime - 1)
safety_range = 1 << (2048 - 64)
check_g_in_range(g_a, safety_range, dh_prime - safety_range)
check_g_in_range(g_b, safety_range, dh_prime - safety_range)
client_dh_inner = bytes(
ClientDhInnerData(
nonce=nonce,
server_nonce=server_nonce,
retry_id=0, # TODO use an actual retry_id
g_b=g_b.to_bytes((g_b.bit_length() + 7) // 8),
)
)
client_dh_inner_hashed = sha1(client_dh_inner).digest() + client_dh_inner
client_dh_inner_hashed += random_bytes[
: (16 - (len(client_dh_inner_hashed) % 16)) % 16
]
client_dh_encrypted = encrypt_ige(client_dh_inner_hashed, key, iv)
return set_client_dh_params(
nonce=nonce, server_nonce=server_nonce, encrypted_data=client_dh_encrypted
), Step3(
nonce=nonce,
server_nonce=server_nonce,
new_nonce=new_nonce,
gab=gab,
time_offset=time_offset,
)
def step3(data: Step2, response: bytes) -> Tuple[bytes, Step3]:
return _do_step3(data, response, os.urandom(272), int(time.time()))
def create_key(data: Step3, response: bytes) -> CreatedKey:
nonce = data.nonce
server_nonce = data.server_nonce
new_nonce = data.new_nonce
gab = data.gab
time_offset = data.time_offset
dh_gen_answer = SetClientDhParamsAnswer.from_bytes(response)
if isinstance(dh_gen_answer, DhGenOk):
dh_gen = DhGenData(
nonce=dh_gen_answer.nonce,
server_nonce=dh_gen_answer.server_nonce,
new_nonce_hash=dh_gen_answer.new_nonce_hash1,
nonce_number=1,
)
elif isinstance(dh_gen_answer, DhGenRetry):
dh_gen = DhGenData(
nonce=dh_gen_answer.nonce,
server_nonce=dh_gen_answer.server_nonce,
new_nonce_hash=dh_gen_answer.new_nonce_hash2,
nonce_number=2,
)
elif isinstance(dh_gen_answer, DhGenFail):
dh_gen = DhGenData(
nonce=dh_gen_answer.nonce,
server_nonce=dh_gen_answer.server_nonce,
new_nonce_hash=dh_gen_answer.new_nonce_hash3,
nonce_number=3,
)
else:
raise ValueError(f"unknown dh gen answer type: {dh_gen_answer}")
check_nonce(dh_gen.nonce, nonce)
check_server_nonce(dh_gen.server_nonce, server_nonce)
auth_key = AuthKey.from_bytes(gab.to_bytes(256))
new_nonce_hash = auth_key.calc_new_nonce_hash(new_nonce, dh_gen.nonce_number)
check_new_nonce_hash(dh_gen.new_nonce_hash, new_nonce_hash)
first_salt = struct.unpack(
"<q",
bytes(
a ^ b
for a, b in zip(new_nonce.to_bytes(32)[:8], server_nonce.to_bytes(16)[:8])
),
)[0]
if dh_gen.nonce_number == 1:
return CreatedKey(
auth_key=auth_key,
time_offset=time_offset,
first_salt=first_salt,
)
else:
raise ValueError("dh gen fail")
def check_nonce(got: int, expected: int) -> None:
if got != expected:
raise ValueError(f"invalid nonce, expected: {expected}, got: {got}")
def check_server_nonce(got: int, expected: int) -> None:
if got != expected:
raise ValueError(f"invalid server nonce, expected: {expected}, got: {got}")
def check_new_nonce_hash(got: int, expected: int) -> None:
if got != expected:
raise ValueError(f"invalid new nonce, expected: {expected}, got: {got}")
def check_g_in_range(value: int, low: int, high: int) -> None:
if not (low < value < high):
raise ValueError(f"g parameter {value} not in range({low+1}, {high})")

View File

@@ -0,0 +1,12 @@
from .encrypted import Encrypted
from .plain import Plain
from .types import Deserialization, MsgId, Mtp, RpcError
__all__ = [
"Encrypted",
"Plain",
"Deserialization",
"MsgId",
"Mtp",
"RpcError",
]

View File

@@ -0,0 +1,409 @@
import os
import struct
import time
from typing import List, Optional, Tuple, Union
from ...crypto import AuthKey, decrypt_data_v2, encrypt_data_v2
from ...tl.mtproto.abcs import BadMsgNotification as AbcBadMsgNotification
from ...tl.mtproto.abcs import DestroySessionRes
from ...tl.mtproto.abcs import MsgDetailedInfo as AbcMsgDetailedInfo
from ...tl.mtproto.functions import get_future_salts
from ...tl.mtproto.types import (
BadMsgNotification,
BadServerSalt,
DestroySessionNone,
DestroySessionOk,
FutureSalt,
FutureSalts,
GzipPacked,
HttpWait,
Message,
MsgContainer,
MsgDetailedInfo,
MsgNewDetailedInfo,
MsgResendReq,
MsgsAck,
MsgsAllInfo,
MsgsStateInfo,
MsgsStateReq,
NewSessionCreated,
Pong,
RpcAnswerDropped,
RpcAnswerDroppedRunning,
RpcAnswerUnknown,
)
from ...tl.mtproto.types import RpcError as GeneratedRpcError
from ...tl.mtproto.types import RpcResult
from ...tl.types import (
Updates,
UpdatesCombined,
UpdateShort,
UpdateShortChatMessage,
UpdateShortMessage,
UpdateShortSentMessage,
UpdatesTooLong,
)
from ..utils import (
CONTAINER_MAX_LENGTH,
CONTAINER_MAX_SIZE,
DEFAULT_COMPRESSION_THRESHOLD,
MESSAGE_SIZE_OVERHEAD,
check_message_buffer,
gzip_compress,
gzip_decompress,
message_requires_ack,
)
from .types import Deserialization, MsgId, Mtp, RpcError
NUM_FUTURE_SALTS = 64
SALT_USE_DELAY = 60
UPDATE_IDS = {
Updates.constructor_id(),
UpdatesCombined.constructor_id(),
UpdateShort.constructor_id(),
UpdateShortChatMessage.constructor_id(),
UpdateShortMessage.constructor_id(),
UpdateShortSentMessage.constructor_id(),
UpdatesTooLong.constructor_id(),
}
HEADER_LEN = 8 + 8 # salt, client_id
CONTAINER_HEADER_LEN = (8 + 4 + 4) + (4 + 4) # msg_id, seq_no, size, constructor, len
class Encrypted(Mtp):
def __init__(
self,
auth_key: bytes,
*,
time_offset: Optional[int] = None,
first_salt: Optional[int] = None,
compression_threshold: Optional[int] = DEFAULT_COMPRESSION_THRESHOLD,
) -> None:
self._auth_key: AuthKey = AuthKey.from_bytes(auth_key)
self._time_offset: int = time_offset or 0
self._salts: List[FutureSalt] = [
FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)
]
self._start_salt_time: Optional[Tuple[int, int]] = None
self._client_id: int = struct.unpack("<q", os.urandom(8))[0]
self._sequence: int = 0
self._last_msg_id: int = 0
self._pending_ack: List[int] = []
self._compression_threshold = compression_threshold
self._rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]] = []
self._updates: List[bytes] = []
self._buffer = bytearray()
self._msg_count: int = 0
self._handlers = {
RpcResult.constructor_id(): self._handle_rpc_result,
MsgsAck.constructor_id(): self._handle_ack,
BadMsgNotification.constructor_id(): self._handle_bad_notification,
BadServerSalt.constructor_id(): self._handle_bad_notification,
MsgsStateReq.constructor_id(): self._handle_state_req,
MsgsStateInfo.constructor_id(): self._handle_state_info,
MsgsAllInfo.constructor_id(): self._handle_msg_all,
MsgDetailedInfo.constructor_id(): self._handle_detailed_info,
MsgNewDetailedInfo.constructor_id(): self._handle_detailed_info,
MsgResendReq.constructor_id(): self._handle_msg_resend,
FutureSalt.constructor_id(): self._handle_future_salt,
FutureSalts.constructor_id(): self._handle_future_salts,
Pong.constructor_id(): self._handle_pong,
DestroySessionOk.constructor_id(): self._handle_destroy_session,
DestroySessionNone.constructor_id(): self._handle_destroy_session,
NewSessionCreated.constructor_id(): self._handle_new_session_created,
GzipPacked.constructor_id(): self._handle_gzip_packed,
HttpWait.constructor_id(): self._handle_http_wait,
}
@property
def auth_key(self) -> bytes:
return self._auth_key.data
def _correct_time_offset(self, msg_id: int) -> None:
now = time.time()
correct = msg_id >> 32
self._time_offset = correct - int(now)
def _get_new_msg_id(self) -> int:
now = time.time()
new_msg_id = int((now + self._time_offset) * 0x100000000)
if self._last_msg_id >= new_msg_id:
new_msg_id = self._last_msg_id + 4
self._last_msg_id = new_msg_id
return new_msg_id
def _get_seq_no(self, content_related: bool) -> int:
if content_related:
self._sequence += 2
return self._sequence - 1
else:
return self._sequence
def _serialize_msg(self, body: bytes, content_related: bool) -> MsgId:
msg_id = self._get_new_msg_id()
seq_no = self._get_seq_no(content_related)
self._buffer += struct.pack("<qii", msg_id, seq_no, len(body))
self._buffer += body
self._msg_count += 1
return MsgId(msg_id)
def _finalize_plain(self) -> bytes:
if not self._msg_count:
return b""
if self._msg_count == 1:
del self._buffer[:CONTAINER_HEADER_LEN]
self._buffer[:HEADER_LEN] = struct.pack(
"<qq", self._salts[-1].salt if self._salts else 0, self._client_id
)
if self._msg_count != 1:
self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
"<qiiIi",
self._get_new_msg_id(),
self._get_seq_no(False),
len(self._buffer) - HEADER_LEN - CONTAINER_HEADER_LEN + 8,
MsgContainer.constructor_id(),
self._msg_count,
)
self._msg_count = 0
result = bytes(self._buffer)
self._buffer.clear()
return result
def _process_message(self, message: Message) -> None:
if message_requires_ack(message):
self._pending_ack.append(message.msg_id)
# https://core.telegram.org/mtproto/service_messages
# https://core.telegram.org/mtproto/service_messages_about_messages
# TODO verify what needs ack and what doesn't
constructor_id = struct.unpack_from("<I", message.body)[0]
self._handlers.get(constructor_id, self._handle_update)(message)
def _handle_rpc_result(self, message: Message) -> None:
assert isinstance(message.body, RpcResult)
req_msg_id = message.body.req_msg_id
result = message.body.result
msg_id = MsgId(req_msg_id)
inner_constructor = struct.unpack_from("<I", result)[0]
if inner_constructor == GeneratedRpcError.constructor_id():
self._rpc_results.append(
(
msg_id,
RpcError.from_mtproto_error(GeneratedRpcError.from_bytes(result)),
)
)
elif inner_constructor == RpcAnswerUnknown.constructor_id():
pass # msg_id = rpc_drop_answer.msg_id
elif inner_constructor == RpcAnswerDroppedRunning.constructor_id():
pass # msg_id = rpc_drop_answer.msg_id, original_request.msg_id
elif inner_constructor == RpcAnswerDropped.constructor_id():
pass # dropped
elif inner_constructor == GzipPacked.constructor_id():
body = gzip_decompress(GzipPacked.from_bytes(result))
self._store_own_updates(body)
self._rpc_results.append((msg_id, body))
else:
self._store_own_updates(result)
self._rpc_results.append((msg_id, result))
def _store_own_updates(self, body: bytes) -> None:
constructor_id = struct.unpack_from("I", body)[0]
if constructor_id in UPDATE_IDS:
self._updates.append(body)
def _handle_ack(self, message: Message) -> None:
# TODO notify about this somehow
MsgsAck.from_bytes(message.body)
def _handle_bad_notification(self, message: Message) -> None:
# TODO notify about this somehow
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
if isinstance(bad_msg, BadServerSalt):
self._rpc_results.append(
(
MsgId(bad_msg.bad_msg_id),
ValueError(f"bad msg: {bad_msg.error_code}"),
)
)
self._salts.clear()
self._salts.append(
FutureSalt(
valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt
)
)
self.push(get_future_salts(num=NUM_FUTURE_SALTS))
return
assert isinstance(bad_msg, BadMsgNotification)
self._rpc_results.append(
(MsgId(bad_msg.bad_msg_id), ValueError(f"bad msg: {bad_msg.error_code}"))
)
if bad_msg.error_code in (16, 17):
self._correct_time_offset(message.msg_id)
elif bad_msg.error_code == 32:
# TODO start with a fresh session rather than guessing
self._sequence += 64
elif bad_msg.error_code == 33:
# TODO start with a fresh session rather than guessing
self._sequence -= 16
def _handle_state_req(self, message: Message) -> None:
# TODO implement
MsgsStateReq.from_bytes(message.body)
def _handle_state_info(self, message: Message) -> None:
# TODO implement
MsgsStateInfo.from_bytes(message.body)
def _handle_msg_all(self, message: Message) -> None:
# TODO implement
MsgsAllInfo.from_bytes(message.body)
def _handle_detailed_info(self, message: Message) -> None:
# TODO properly implement
msg_detailed = AbcMsgDetailedInfo.from_bytes(message.body)
if isinstance(msg_detailed, MsgDetailedInfo):
self._pending_ack.append(msg_detailed.answer_msg_id)
elif isinstance(msg_detailed, MsgNewDetailedInfo):
self._pending_ack.append(msg_detailed.answer_msg_id)
else:
assert False
def _handle_msg_resend(self, message: Message) -> None:
# TODO implement
MsgResendReq.from_bytes(message.body)
def _handle_future_salts(self, message: Message) -> None:
# TODO implement
salts = FutureSalts.from_bytes(message.body)
self._rpc_results.append((MsgId(salts.req_msg_id), message.body))
self._start_salt_time = (salts.now, int(time.time()))
self._salts = salts.salts
self._salts.sort(key=lambda salt: -salt.valid_since)
def _handle_future_salt(self, message: Message) -> None:
FutureSalt.from_bytes(message.body)
assert False # no request should cause this
def _handle_pong(self, message: Message) -> None:
pong = Pong.from_bytes(message.body)
self._rpc_results.append((MsgId(pong.msg_id), message.body))
def _handle_destroy_session(self, message: Message) -> None:
# TODO implement
DestroySessionRes.from_bytes(message.body)
def _handle_new_session_created(self, message: Message) -> None:
# TODO implement
new_session = NewSessionCreated.from_bytes(message.body)
self._salts.clear()
self._salts.append(
FutureSalt(
valid_since=0, valid_until=0x7FFFFFFF, salt=new_session.server_salt
)
)
def _handle_container(self, message: Message) -> None:
container = MsgContainer.from_bytes(message.body)
for inner_message in container.messages:
self._process_message(inner_message)
def _handle_gzip_packed(self, message: Message) -> None:
container = GzipPacked.from_bytes(message.body)
inner_body = gzip_decompress(container)
self._process_message(
Message(
msg_id=message.msg_id,
seqno=message.seqno,
bytes=len(inner_body),
body=inner_body,
)
)
def _handle_http_wait(self, message: Message) -> None:
# TODO implement
HttpWait.from_bytes(message.body)
def _handle_update(self, message: Message) -> None:
# TODO if this `Updates` cannot be deserialized, `getDifference` should be used
self._updates.append(message.body)
def push(self, request: bytes) -> Optional[MsgId]:
if not self._buffer:
# Reserve space for `finalize`
self._buffer += bytes(HEADER_LEN + CONTAINER_HEADER_LEN)
if self._pending_ack:
self._serialize_msg(bytes(MsgsAck(msg_ids=self._pending_ack)), False)
self._pending_ack = []
if self._start_salt_time:
start_secs, start_instant = self._start_salt_time
if len(self._salts) >= 2:
salt = self._salts[-2]
now = start_secs + (start_instant - int(time.time()))
if now >= salt.valid_since + SALT_USE_DELAY:
self._salts.pop()
if len(self._salts) == 1:
self._serialize_msg(
bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True
)
if self._msg_count == CONTAINER_MAX_LENGTH:
return None
assert len(request) + MESSAGE_SIZE_OVERHEAD <= CONTAINER_MAX_SIZE
assert len(request) % 4 == 0
body = request
if self._compression_threshold is not None:
if len(request) >= self._compression_threshold:
compressed = bytes(GzipPacked(packed_data=gzip_compress(request)))
if len(compressed) < len(request):
body = compressed
new_size = len(self._buffer) + len(body) + MESSAGE_SIZE_OVERHEAD
if new_size >= CONTAINER_MAX_SIZE:
return None
return self._serialize_msg(body, True)
def finalize(self) -> bytes:
buffer = self._finalize_plain()
if not buffer:
return buffer
else:
return encrypt_data_v2(buffer, self._auth_key)
def deserialize(self, payload: bytes) -> Deserialization:
check_message_buffer(payload)
plaintext = decrypt_data_v2(payload, self._auth_key)
_, client_id = struct.unpack_from("<qq", plaintext) # salt, client_id
if client_id != self._client_id:
raise RuntimeError("wrong session id")
self._process_message(Message.from_bytes(memoryview(plaintext)[16:]))
result = Deserialization(rpc_results=self._rpc_results, updates=self._updates)
self._rpc_results = []
self._updates = []
return result

View File

@@ -0,0 +1,52 @@
import struct
from typing import Optional
from ..utils import check_message_buffer
from .types import Deserialization, MsgId, Mtp
class Plain(Mtp):
def __init__(self) -> None:
self._buffer = bytearray()
# https://core.telegram.org/mtproto/description#unencrypted-message
def push(self, request: bytes) -> Optional[MsgId]:
if self._buffer:
return None
# https://core.telegram.org/mtproto/samples-auth_key seems to
# imply a need to generate a valid `message_id`, but 0 works too.
msg_id = MsgId(0)
# auth_key_id = 0, message_id, message_data_length.
self._buffer += struct.pack("<qqi", 0, msg_id, len(request))
self._buffer += request # message_data
return msg_id
def finalize(self) -> bytes:
result = bytes(self._buffer)
self._buffer.clear()
return result
def deserialize(self, payload: bytes) -> Deserialization:
check_message_buffer(payload)
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
if auth_key_id != 0:
raise ValueError(f"bad auth key, expected: 0, got: {auth_key_id}")
# https://core.telegram.org/mtproto/description#message-identifier-msg-id
if msg_id <= 0 or (msg_id % 4) != 1:
raise ValueError(f"bad msg id, got: {msg_id}")
if length < 0:
raise ValueError(f"bad length: expected >= 0, got: {length}")
if 20 + length > len(payload):
raise ValueError(
f"message too short, expected: {20 + length}, got {len(payload)}"
)
return Deserialization(
rpc_results=[(MsgId(0), payload[20 : 20 + length])], updates=[]
)

View File

@@ -0,0 +1,76 @@
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, NewType, Optional, Self, Tuple, Union
from ...tl.mtproto.types import RpcError as GeneratedRpcError
MsgId = NewType("MsgId", int)
@dataclass
class Deserialization:
rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]]
updates: List[bytes]
class RpcError(ValueError):
def __init__(
self,
*,
code: int = 0,
name: str = "",
value: Optional[int] = None,
caused_by: Optional[int] = None,
) -> None:
append_value = f" ({value})" if value else None
super().__init__(f"rpc error {code}: {name}{append_value}")
self.code = code
self.name = name
self.value = value
self.caused_by = caused_by
@classmethod
def from_mtproto_error(cls, error: GeneratedRpcError) -> Self:
if m := re.search(r"-?\d+", error.error_message):
name = re.sub(
r"_{2,}",
"_",
error.error_message[: m.start()] + error.error_message[m.end() :],
).strip("_")
value = int(m[0])
else:
name = error.error_message
value = None
return cls(
code=error.error_code,
name=name,
value=value,
caused_by=None,
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return (
self.code == other.code
and self.name == other.name
and self.value == other.value
)
# https://core.telegram.org/mtproto/description
class Mtp(ABC):
@abstractmethod
def push(self, request: bytes) -> Optional[MsgId]:
pass
@abstractmethod
def finalize(self) -> bytes:
pass
@abstractmethod
def deserialize(self, payload: bytes) -> Deserialization:
pass

View File

@@ -0,0 +1,3 @@
from .abcs import Transport
__all__ = ["Transport"]

View File

@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
class Transport(ABC):
@abstractmethod
def pack(self, input: bytes, output: bytearray) -> None:
pass
@abstractmethod
def unpack(self, input: bytes, output: bytearray) -> None:
pass

View File

@@ -0,0 +1,58 @@
import struct
from .abcs import Transport
class Abridged(Transport):
__slots__ = ("_init",)
"""
Implementation of the [abridged transport]:
```text
+----+----...----+
| len| payload |
+----+----...----+
^^^^ 1 or 4 bytes
```
[abridged transport]: https://core.telegram.org/mtproto/mtproto-transports#abridged
"""
def __init__(self) -> None:
self._init = False
def pack(self, input: bytes, output: bytearray) -> None:
assert len(input) % 4 == 0
if not self._init:
output += b"\xef"
self._init = True
length = len(input) // 4
if length < 127:
output += struct.pack("<b", length)
else:
output += struct.pack("<i", 0x7F | (length << 8))
output += input
def unpack(self, input: bytes, output: bytearray) -> None:
if not input:
raise ValueError("missing bytes, expected: 1, got: 0")
length = input[0]
if length < 127:
header_len = 1
elif len(input) < 4:
raise ValueError(f"missing bytes, expected: 4, got: {len(input)}")
else:
header_len = 4
length = struct.unpack_from("<i", input)[0] >> 8
length *= 4
if len(input) < header_len + length:
raise ValueError(
f"missing bytes, expected: {header_len + length}, got: {len(input)}"
)
output += memoryview(input)[header_len : header_len + length]

View File

@@ -0,0 +1,57 @@
import struct
from zlib import crc32
from .abcs import Transport
class Full(Transport):
__slots__ = ("_send_seq", "_recv_seq")
"""
Implementation of the [full transport]:
```text
+----+----+----...----+----+
| len| seq| payload | crc|
+----+----+----...----+----+
^^^^ 4 bytes
```
[full transport]: https://core.telegram.org/mtproto/mtproto-transports#full
"""
def __init__(self) -> None:
self._send_seq = 0
self._recv_seq = 0
def pack(self, input: bytes, output: bytearray) -> None:
assert len(input) % 4 == 0
length = len(input) + 12
output += struct.pack("<ii", length, self._send_seq)
output += input
output += struct.pack("<i", crc32(memoryview(output)[-(length - 4) :]))
self._send_seq += 1
def unpack(self, input: bytes, output: bytearray) -> None:
if len(input) < 4:
raise ValueError(f"missing bytes, expected: 4, got: {len(input)}")
length = struct.unpack_from("<i", input)[0]
if length < 12:
raise ValueError(f"bad length, expected > 12, got: {length}")
if len(input) < length:
raise ValueError(f"missing bytes, expected: {length}, got: {len(input)}")
seq = struct.unpack_from("<i", input, 4)[0]
if seq != self._recv_seq:
raise ValueError(f"bad seq, expected: {self._recv_seq}, got: {seq}")
crc = struct.unpack_from("<I", input, length - 4)[0]
valid_crc = crc32(memoryview(input)[:-4])
if crc != valid_crc:
raise ValueError(f"bad crc, expected: {valid_crc}, got: {crc}")
self._recv_seq += 1
output += memoryview(input)[8:-4]

View File

@@ -0,0 +1,43 @@
import struct
from .abcs import Transport
class Intermediate(Transport):
__slots__ = ("_init",)
"""
Implementation of the [intermediate transport]:
```text
+----+----...----+
| len| payload |
+----+----...----+
^^^^ 4 bytes
```
[intermediate transport]: https://core.telegram.org/mtproto/mtproto-transports#intermediate
"""
def __init__(self) -> None:
self._init = False
def pack(self, input: bytes, output: bytearray) -> None:
assert len(input) % 4 == 0
if not self._init:
output += b"\xee\xee\xee\xee"
self._init = True
output += struct.pack("<i", len(input))
output += input
def unpack(self, input: bytes, output: bytearray) -> None:
if len(input) < 4:
raise ValueError(f"missing bytes, expected: {4}, got: {len(input)}")
length = struct.unpack_from("<i", input)[0]
if len(input) < length:
raise ValueError(f"missing bytes, expected: {length}, got: {len(input)}")
output += memoryview(input)[4 : 4 + length]

View File

@@ -0,0 +1,34 @@
import gzip
import struct
from typing import Optional
from ..tl.mtproto.types import GzipPacked, Message
DEFAULT_COMPRESSION_THRESHOLD: Optional[int] = 512
CONTAINER_SIZE_OVERHEAD = 4 + 4 # constructor_id, inner vec length
CONTAINER_MAX_SIZE = 1_044_456 - CONTAINER_SIZE_OVERHEAD
CONTAINER_MAX_LENGTH = 100
MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
def check_message_buffer(message: bytes) -> None:
if len(message) == 4:
neg_http_code = struct.unpack("<i", message)[0]
raise ValueError(f"transport error: {neg_http_code}")
elif len(message) < 20:
raise ValueError(
f"server payload is too small to be a valid message: {message.hex()}"
)
# https://core.telegram.org/mtproto/description#content-related-message
def message_requires_ack(message: Message) -> bool:
return message.seqno % 2 == 1
def gzip_decompress(gzip_packed: GzipPacked) -> bytes:
return gzip.decompress(gzip_packed.packed_data)
def gzip_compress(unpacked_data: bytes) -> bytes:
return gzip.compress(unpacked_data)

View File

@@ -29,35 +29,41 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
class Reader:
__slots__ = ("_buffer", "_pos", "_view")
__slots__ = ("_view", "_pos", "_len")
def __init__(self, buffer: bytes) -> None:
self._buffer = buffer
self._view = (
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
)
self._pos = 0
self._view = memoryview(self._buffer)
self._len = len(self._view)
def read_remaining(self) -> bytes:
return self.read(self._len - self._pos)
def read(self, n: int) -> bytes:
self._pos += n
return self._view[self._pos - n : n]
assert self._pos <= self._len
return self._view[self._pos - n : self._pos]
def read_fmt(self, fmt: str, size: int) -> tuple[Any, ...]:
assert struct.calcsize(fmt) == size
self._pos += size
assert self._pos <= self._len
return struct.unpack(fmt, self._view[self._pos - size : self._pos])
def read_bytes(self) -> bytes:
if self._buffer[self._pos] == 254:
if self._view[self._pos] == 254:
self._pos += 4
(length,) = struct.unpack(
"<i", self._buffer[self._pos - 3 : self._pos] + b"\0"
)
length = struct.unpack("<i", self._view[self._pos - 4 : self._pos])[0] >> 8
padding = length % 4
else:
length = self._buffer[self._pos]
length = self._view[self._pos]
padding = (length + 1) % 4
self._pos += 1
self._pos += length
assert self._pos <= self._len
data = self._view[self._pos - length : self._pos]
if padding > 0:
self._pos += 4 - padding
@@ -72,6 +78,7 @@ class Reader:
# Unfortunately `typing.cast` would add a tiny amount of runtime overhead
# which cannot be removed with optimization enabled.
self._pos += 4
assert self._pos <= self._len
cid = struct.unpack("<I", self._view[self._pos - 4 : self._pos])[0]
ty = self._get_ty(cid)
if ty is None:

View File

@@ -1,16 +1,13 @@
import struct
class Request:
__slots__ = "_body"
def __init__(self, body: bytes):
self._body = body
class Request(bytes):
__slots__ = ()
@property
def constructor_id(self) -> int:
try:
cid = struct.unpack("<i", self._body[:4])[0]
cid = struct.unpack("<i", self[:4])[0]
assert isinstance(cid, int)
return cid
except struct.error:

View File

@@ -35,7 +35,12 @@ class Serializable(abc.ABC):
return bytes(buffer)
def __repr__(self) -> str:
attrs = ", ".join(repr(getattr(self, attr)) for attr in self.__slots__)
fields = ((attr, getattr(self, attr)) for attr in self.__slots__)
fields = (
(name, bytes(field) if isinstance(field, memoryview) else field)
for name, field in fields
)
attrs = ", ".join(f"{name}={field!r}" for name, field in fields)
return f"{self.__class__.__name__}({attrs})"
def __eq__(self, other: object) -> bool: