mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-06-18 11:06:39 +00:00
Fix handling of salts and container buffer
This commit is contained in:
parent
6ed279e773
commit
c91ce98a25
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
@ -89,16 +90,12 @@ class Encrypted(Mtp):
|
|||||||
self._salts: List[FutureSalt] = [
|
self._salts: List[FutureSalt] = [
|
||||||
FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)
|
FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)
|
||||||
]
|
]
|
||||||
self._start_salt_time: Optional[Tuple[int, int]] = None
|
self._start_salt_time: Optional[Tuple[int, float]] = None
|
||||||
self._client_id: int = struct.unpack("<q", os.urandom(8))[0]
|
|
||||||
self._sequence: int = 0
|
|
||||||
self._last_msg_id: int = 0
|
|
||||||
self._pending_ack: List[int] = []
|
|
||||||
self._compression_threshold = compression_threshold
|
self._compression_threshold = compression_threshold
|
||||||
self._rpc_results: List[Tuple[MsgId, RpcResult]] = []
|
self._rpc_results: List[Tuple[MsgId, RpcResult]] = []
|
||||||
self._updates: List[bytes] = []
|
self._updates: List[bytes] = []
|
||||||
self._buffer = bytearray()
|
self._buffer = bytearray()
|
||||||
self._msg_count: int = 0
|
self._salt_request_msg_id: Optional[int] = None
|
||||||
|
|
||||||
self._handlers = {
|
self._handlers = {
|
||||||
GeneratedRpcResult.constructor_id(): self._handle_rpc_result,
|
GeneratedRpcResult.constructor_id(): self._handle_rpc_result,
|
||||||
@ -122,6 +119,13 @@ class Encrypted(Mtp):
|
|||||||
HttpWait.constructor_id(): self._handle_http_wait,
|
HttpWait.constructor_id(): self._handle_http_wait,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self._client_id: int
|
||||||
|
self._sequence: int
|
||||||
|
self._last_msg_id: int
|
||||||
|
self._pending_ack: List[int] = []
|
||||||
|
self._msg_count: int
|
||||||
|
self._reset_session()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def auth_key(self) -> bytes:
|
def auth_key(self) -> bytes:
|
||||||
return self._auth_key.data
|
return self._auth_key.data
|
||||||
@ -131,10 +135,18 @@ class Encrypted(Mtp):
|
|||||||
correct = msg_id >> 32
|
correct = msg_id >> 32
|
||||||
self._time_offset = correct - int(now)
|
self._time_offset = correct - int(now)
|
||||||
|
|
||||||
def _get_new_msg_id(self) -> int:
|
def _adjusted_now(self) -> float:
|
||||||
now = time.time()
|
return time.time() + self._time_offset
|
||||||
|
|
||||||
new_msg_id = int((now + self._time_offset) * 0x100000000)
|
def _reset_session(self) -> None:
|
||||||
|
self._client_id = struct.unpack("<q", os.urandom(8))[0]
|
||||||
|
self._sequence = 0
|
||||||
|
self._last_msg_id = 0
|
||||||
|
self._pending_ack.clear()
|
||||||
|
self._msg_count = 0
|
||||||
|
|
||||||
|
def _get_new_msg_id(self) -> int:
|
||||||
|
new_msg_id = int(self._adjusted_now() * 0x100000000)
|
||||||
if self._last_msg_id >= new_msg_id:
|
if self._last_msg_id >= new_msg_id:
|
||||||
new_msg_id = self._last_msg_id + 4
|
new_msg_id = self._last_msg_id + 4
|
||||||
|
|
||||||
@ -149,6 +161,10 @@ class Encrypted(Mtp):
|
|||||||
return self._sequence
|
return self._sequence
|
||||||
|
|
||||||
def _serialize_msg(self, body: bytes, content_related: bool) -> MsgId:
|
def _serialize_msg(self, body: bytes, content_related: bool) -> MsgId:
|
||||||
|
if not self._buffer:
|
||||||
|
# Reserve space for `finalize`
|
||||||
|
self._buffer += bytes(HEADER_LEN + CONTAINER_HEADER_LEN)
|
||||||
|
|
||||||
msg_id = self._get_new_msg_id()
|
msg_id = self._get_new_msg_id()
|
||||||
seq_no = self._get_seq_no(content_related)
|
seq_no = self._get_seq_no(content_related)
|
||||||
self._buffer += struct.pack("<qii", msg_id, seq_no, len(body))
|
self._buffer += struct.pack("<qii", msg_id, seq_no, len(body))
|
||||||
@ -156,6 +172,9 @@ class Encrypted(Mtp):
|
|||||||
self._msg_count += 1
|
self._msg_count += 1
|
||||||
return MsgId(msg_id)
|
return MsgId(msg_id)
|
||||||
|
|
||||||
|
def _get_current_salt(self) -> int:
|
||||||
|
return self._salts[-1].salt if self._salts else 0
|
||||||
|
|
||||||
def _finalize_plain(self) -> bytes:
|
def _finalize_plain(self) -> bytes:
|
||||||
if not self._msg_count:
|
if not self._msg_count:
|
||||||
return b""
|
return b""
|
||||||
@ -164,7 +183,7 @@ class Encrypted(Mtp):
|
|||||||
del self._buffer[:CONTAINER_HEADER_LEN]
|
del self._buffer[:CONTAINER_HEADER_LEN]
|
||||||
|
|
||||||
self._buffer[:HEADER_LEN] = struct.pack(
|
self._buffer[:HEADER_LEN] = struct.pack(
|
||||||
"<qq", self._salts[-1].salt if self._salts else 0, self._client_id
|
"<qq", self._get_current_salt(), self._client_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._msg_count != 1:
|
if self._msg_count != 1:
|
||||||
@ -177,6 +196,7 @@ class Encrypted(Mtp):
|
|||||||
self._msg_count,
|
self._msg_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print("packed", self._msg_count)
|
||||||
self._msg_count = 0
|
self._msg_count = 0
|
||||||
result = bytes(self._buffer)
|
result = bytes(self._buffer)
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
@ -230,37 +250,28 @@ class Encrypted(Mtp):
|
|||||||
|
|
||||||
def _handle_bad_notification(self, message: Message) -> None:
|
def _handle_bad_notification(self, message: Message) -> None:
|
||||||
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
|
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
|
||||||
if isinstance(bad_msg, BadServerSalt):
|
assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification))
|
||||||
self._rpc_results.append(
|
|
||||||
(
|
|
||||||
MsgId(bad_msg.bad_msg_id),
|
|
||||||
BadMessage(code=bad_msg.error_code),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
exc = BadMessage(code=bad_msg.error_code)
|
||||||
|
self._rpc_results.append((MsgId(bad_msg.bad_msg_id), exc))
|
||||||
|
if isinstance(bad_msg, BadServerSalt) and self._get_current_salt() == 0:
|
||||||
|
# If we had no valid salt, this error is expected.
|
||||||
|
exc.severity = logging.INFO
|
||||||
|
|
||||||
|
if isinstance(bad_msg, BadServerSalt):
|
||||||
self._salts.clear()
|
self._salts.clear()
|
||||||
self._salts.append(
|
self._salts.append(
|
||||||
FutureSalt(
|
FutureSalt(
|
||||||
valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt
|
valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self._salt_request_msg_id = None
|
||||||
self.push(get_future_salts(num=NUM_FUTURE_SALTS))
|
elif bad_msg.error_code not in (16, 17):
|
||||||
return
|
|
||||||
|
|
||||||
assert isinstance(bad_msg, BadMsgNotification)
|
|
||||||
self._rpc_results.append(
|
|
||||||
(MsgId(bad_msg.bad_msg_id), BadMessage(code=bad_msg.error_code))
|
|
||||||
)
|
|
||||||
|
|
||||||
if bad_msg.error_code in (16, 17):
|
|
||||||
self._correct_time_offset(message.msg_id)
|
self._correct_time_offset(message.msg_id)
|
||||||
elif bad_msg.error_code == 32:
|
elif bad_msg.error_code in (32, 33):
|
||||||
# TODO start with a fresh session rather than guessing
|
self._reset_session()
|
||||||
self._sequence += 64
|
else:
|
||||||
elif bad_msg.error_code == 33:
|
raise exc
|
||||||
# TODO start with a fresh session rather than guessing
|
|
||||||
self._sequence -= 16
|
|
||||||
|
|
||||||
def _handle_state_req(self, message: Message) -> None:
|
def _handle_state_req(self, message: Message) -> None:
|
||||||
MsgsStateReq.from_bytes(message.body)
|
MsgsStateReq.from_bytes(message.body)
|
||||||
@ -285,9 +296,14 @@ class Encrypted(Mtp):
|
|||||||
|
|
||||||
def _handle_future_salts(self, message: Message) -> None:
|
def _handle_future_salts(self, message: Message) -> None:
|
||||||
salts = FutureSalts.from_bytes(message.body)
|
salts = FutureSalts.from_bytes(message.body)
|
||||||
self._rpc_results.append((MsgId(salts.req_msg_id), message.body))
|
|
||||||
|
|
||||||
self._start_salt_time = (salts.now, int(time.time()))
|
if salts.req_msg_id == self._salt_request_msg_id:
|
||||||
|
# Response to internal request, do not propagate.
|
||||||
|
self._salt_request_msg_id = None
|
||||||
|
else:
|
||||||
|
self._rpc_results.append((MsgId(salts.req_msg_id), message.body))
|
||||||
|
|
||||||
|
self._start_salt_time = (salts.now, self._adjusted_now())
|
||||||
self._salts = salts.salts
|
self._salts = salts.salts
|
||||||
self._salts.sort(key=lambda salt: -salt.valid_since)
|
self._salts.sort(key=lambda salt: -salt.valid_since)
|
||||||
|
|
||||||
@ -334,28 +350,38 @@ class Encrypted(Mtp):
|
|||||||
def _handle_update(self, message: Message) -> None:
|
def _handle_update(self, message: Message) -> None:
|
||||||
self._updates.append(message.body)
|
self._updates.append(message.body)
|
||||||
|
|
||||||
def push(self, request: bytes) -> Optional[MsgId]:
|
def _try_request_salts(self) -> None:
|
||||||
if not self._buffer:
|
if (
|
||||||
# Reserve space for `finalize`
|
len(self._salts) == 1
|
||||||
self._buffer += bytes(HEADER_LEN + CONTAINER_HEADER_LEN)
|
and self._salt_request_msg_id is None
|
||||||
|
and self._get_current_salt() != 0
|
||||||
|
):
|
||||||
|
# If salts are requested in a container leading to bad_msg,
|
||||||
|
# the bad_msg_id will refer to the container, not the salts request.
|
||||||
|
#
|
||||||
|
# We don't keep track of containers and content-related messages they contain for simplicity.
|
||||||
|
# This would break, because we couldn't identify the response.
|
||||||
|
#
|
||||||
|
# So salts are only requested once we have a valid salt to reduce the chances of this happening.
|
||||||
|
self._salt_request_msg_id = self._serialize_msg(
|
||||||
|
bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True
|
||||||
|
)
|
||||||
|
|
||||||
|
def push(self, request: bytes) -> Optional[MsgId]:
|
||||||
if self._pending_ack:
|
if self._pending_ack:
|
||||||
self._serialize_msg(bytes(MsgsAck(msg_ids=self._pending_ack)), False)
|
self._serialize_msg(bytes(MsgsAck(msg_ids=self._pending_ack)), False)
|
||||||
self._pending_ack = []
|
self._pending_ack = []
|
||||||
|
|
||||||
if self._start_salt_time:
|
if self._start_salt_time and len(self._salts) >= 2:
|
||||||
start_secs, start_instant = self._start_salt_time
|
start_secs, start_instant = self._start_salt_time
|
||||||
if len(self._salts) >= 2:
|
salt = self._salts[-2]
|
||||||
salt = self._salts[-2]
|
now = start_secs + (start_instant - self._adjusted_now())
|
||||||
now = start_secs + (start_instant - int(time.time()))
|
if now >= salt.valid_since + SALT_USE_DELAY:
|
||||||
if now >= salt.valid_since + SALT_USE_DELAY:
|
self._salts.pop()
|
||||||
self._salts.pop()
|
|
||||||
if len(self._salts) == 1:
|
|
||||||
self._serialize_msg(
|
|
||||||
bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._msg_count == CONTAINER_MAX_LENGTH:
|
self._try_request_salts()
|
||||||
|
|
||||||
|
if self._msg_count >= CONTAINER_MAX_LENGTH:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
assert len(request) + MESSAGE_SIZE_OVERHEAD <= CONTAINER_MAX_SIZE
|
assert len(request) + MESSAGE_SIZE_OVERHEAD <= CONTAINER_MAX_SIZE
|
||||||
|
@ -186,24 +186,17 @@ class Sender:
|
|||||||
if self._write_drain_pending:
|
if self._write_drain_pending:
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO test that the a request is only ever sent onrece
|
for request in self._requests:
|
||||||
requests = [r for r in self._requests if isinstance(r.state, NotSerialized)]
|
if isinstance(request.state, NotSerialized):
|
||||||
if not requests:
|
if (msg_id := self._mtp.push(request.body)) is not None:
|
||||||
return
|
request.state = Serialized(msg_id)
|
||||||
|
else:
|
||||||
msg_ids = []
|
break
|
||||||
for request in requests:
|
|
||||||
if (msg_id := self._mtp.push(request.body)) is not None:
|
|
||||||
msg_ids.append(msg_id)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
mtp_buffer = self._mtp.finalize()
|
mtp_buffer = self._mtp.finalize()
|
||||||
self._transport.pack(mtp_buffer, self._writer.write)
|
if mtp_buffer:
|
||||||
self._write_drain_pending = True
|
self._transport.pack(mtp_buffer, self._writer.write)
|
||||||
|
self._write_drain_pending = True
|
||||||
for req, msg_id in zip(requests, msg_ids):
|
|
||||||
req.state = Serialized(msg_id)
|
|
||||||
|
|
||||||
def _on_net_read(self, read_buffer: bytes) -> List[Updates]:
|
def _on_net_read(self, read_buffer: bytes) -> List[Updates]:
|
||||||
if not read_buffer:
|
if not read_buffer:
|
||||||
@ -255,34 +248,47 @@ class Sender:
|
|||||||
updates.append(u)
|
updates.append(u)
|
||||||
|
|
||||||
for msg_id, ret in result.rpc_results:
|
for msg_id, ret in result.rpc_results:
|
||||||
found = False
|
for i, req in enumerate(self._requests):
|
||||||
for i in reversed(range(len(self._requests))):
|
|
||||||
req = self._requests[i]
|
|
||||||
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
|
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
|
||||||
raise RuntimeError("got rpc result for unsent request")
|
raise RuntimeError("got rpc result for unsent request")
|
||||||
if isinstance(req.state, Sent) and req.state.msg_id == msg_id:
|
elif isinstance(req.state, Sent) and req.state.msg_id == msg_id:
|
||||||
found = True
|
del self._requests[i]
|
||||||
if isinstance(ret, bytes):
|
|
||||||
assert len(ret) >= 4
|
|
||||||
elif isinstance(ret, RpcError):
|
|
||||||
ret._caused_by = struct.unpack_from("<I", req.body)[0]
|
|
||||||
raise ret
|
|
||||||
elif isinstance(ret, BadMessage):
|
|
||||||
# TODO test that we resend the request
|
|
||||||
req.state = NotSerialized()
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise RuntimeError("unexpected case")
|
|
||||||
|
|
||||||
req = self._requests.pop(i)
|
|
||||||
req.result.set_result(ret)
|
|
||||||
break
|
break
|
||||||
if not found:
|
else:
|
||||||
self._logger.warning(
|
self._logger.warning(
|
||||||
"telegram sent rpc_result for unknown msg_id=%d: %s",
|
"telegram sent rpc_result for unknown msg_id=%d: %s",
|
||||||
msg_id,
|
msg_id,
|
||||||
ret.hex() if isinstance(ret, bytes) else repr(ret),
|
ret.hex() if isinstance(ret, bytes) else repr(ret),
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(ret, bytes):
|
||||||
|
assert len(ret) >= 4
|
||||||
|
req.result.set_result(ret)
|
||||||
|
elif isinstance(ret, RpcError):
|
||||||
|
ret._caused_by = struct.unpack_from("<I", req.body)[0]
|
||||||
|
req.result.set_exception(ret)
|
||||||
|
elif isinstance(ret, BadMessage):
|
||||||
|
if ret.retryable:
|
||||||
|
self._logger.log(
|
||||||
|
ret.severity,
|
||||||
|
"telegram notified of bad msg_id=%d; will attempt to resend request: %s",
|
||||||
|
msg_id,
|
||||||
|
ret,
|
||||||
|
)
|
||||||
|
req.state = NotSerialized()
|
||||||
|
self._requests.append(req)
|
||||||
|
else:
|
||||||
|
self._logger.log(
|
||||||
|
ret.severity,
|
||||||
|
"telegram notified of bad msg_id=%d; impossible to retry: %s",
|
||||||
|
msg_id,
|
||||||
|
ret,
|
||||||
|
)
|
||||||
|
ret._caused_by = struct.unpack_from("<I", req.body)[0]
|
||||||
|
req.result.set_exception(ret)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("unexpected case")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def auth_key(self) -> Optional[bytes]:
|
def auth_key(self) -> Optional[bytes]:
|
||||||
|
@ -66,6 +66,7 @@ class PackedChat:
|
|||||||
"""
|
"""
|
||||||
return bytes(self).hex()
|
return bytes(self).hex()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def from_hex(cls, hex: str) -> Self:
|
def from_hex(cls, hex: str) -> Self:
|
||||||
"""
|
"""
|
||||||
Convenience method to convert hexadecimal numbers into bytes then passed to :meth:`from_bytes`:
|
Convenience method to convert hexadecimal numbers into bytes then passed to :meth:`from_bytes`:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..tl.core.serializable import obj_repr
|
from ..tl.core.serializable import obj_repr
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ class DataCenter:
|
|||||||
:param auth: See below.
|
:param auth: See below.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ("id", "ipv4_addr", "ipv6_addr", "auth")
|
__slots__: Tuple[str, ...] = ("id", "ipv4_addr", "ipv6_addr", "auth")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user