mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-06-17 10:36:37 +00:00
Merge 3bfa64a5d6
into f852e83363
This commit is contained in:
commit
0eaf2f5fac
@ -10,6 +10,7 @@ from typing_extensions import Self
|
|||||||
|
|
||||||
from ....version import __version__ as default_version
|
from ....version import __version__ as default_version
|
||||||
from ...mtsender import Connector, Sender
|
from ...mtsender import Connector, Sender
|
||||||
|
from ...mtsender.reconnection import NoReconnect, ReconnectionPolicy
|
||||||
from ...session import (
|
from ...session import (
|
||||||
ChannelRef,
|
ChannelRef,
|
||||||
ChatHashCache,
|
ChatHashCache,
|
||||||
@ -215,6 +216,7 @@ class Client:
|
|||||||
lang_code: Optional[str] = None,
|
lang_code: Optional[str] = None,
|
||||||
datacenter: Optional[DataCenter] = None,
|
datacenter: Optional[DataCenter] = None,
|
||||||
connector: Optional[Connector] = None,
|
connector: Optional[Connector] = None,
|
||||||
|
reconnection_policy: Optional[ReconnectionPolicy] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert __package__
|
assert __package__
|
||||||
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
|
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
|
||||||
@ -246,6 +248,7 @@ class Client:
|
|||||||
update_queue_limit=update_queue_limit,
|
update_queue_limit=update_queue_limit,
|
||||||
base_logger=base_logger,
|
base_logger=base_logger,
|
||||||
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
|
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
|
||||||
|
reconnection_policy=reconnection_policy or NoReconnect(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._session = Session()
|
self._session = Session()
|
||||||
@ -253,9 +256,9 @@ class Client:
|
|||||||
self._message_box = MessageBox(base_logger=base_logger)
|
self._message_box = MessageBox(base_logger=base_logger)
|
||||||
self._chat_hashes = ChatHashCache(None)
|
self._chat_hashes = ChatHashCache(None)
|
||||||
self._last_update_limit_warn: Optional[float] = None
|
self._last_update_limit_warn: Optional[float] = None
|
||||||
self._updates: asyncio.Queue[
|
self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = (
|
||||||
tuple[abcs.Update, dict[int, Peer]]
|
asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
|
||||||
] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
|
)
|
||||||
self._dispatcher: Optional[asyncio.Task[None]] = None
|
self._dispatcher: Optional[asyncio.Task[None]] = None
|
||||||
self._handlers: dict[
|
self._handlers: dict[
|
||||||
Type[Event],
|
Type[Event],
|
||||||
|
@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar
|
|||||||
|
|
||||||
from ....version import __version__
|
from ....version import __version__
|
||||||
from ...mtproto import BadStatusError, Full, RpcError
|
from ...mtproto import BadStatusError, Full, RpcError
|
||||||
from ...mtsender import Connector, Sender
|
from ...mtsender import Connector, ReconnectionPolicy, Sender
|
||||||
from ...mtsender import connect as do_connect_sender
|
from ...mtsender import connect as do_connect_sender
|
||||||
from ...session import DataCenter
|
from ...session import DataCenter
|
||||||
from ...session import User as SessionUser
|
from ...session import User as SessionUser
|
||||||
@ -46,6 +46,7 @@ class Config:
|
|||||||
api_hash: str
|
api_hash: str
|
||||||
base_logger: logging.Logger
|
base_logger: logging.Logger
|
||||||
connector: Connector
|
connector: Connector
|
||||||
|
reconnection_policy: ReconnectionPolicy
|
||||||
device_model: str = field(default_factory=default_device_model)
|
device_model: str = field(default_factory=default_device_model)
|
||||||
system_version: str = field(default_factory=default_system_version)
|
system_version: str = field(default_factory=default_system_version)
|
||||||
app_version: str = __version__
|
app_version: str = __version__
|
||||||
@ -100,6 +101,7 @@ async def connect_sender(
|
|||||||
auth_key=auth,
|
auth_key=auth,
|
||||||
base_logger=config.base_logger,
|
base_logger=config.base_logger,
|
||||||
connector=config.connector,
|
connector=config.connector,
|
||||||
|
reconnection_policy=config.reconnection_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -7,6 +7,24 @@ from .aes import ige_decrypt, ige_encrypt
|
|||||||
from .auth_key import AuthKey
|
from .auth_key import AuthKey
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidBufferError(ValueError):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__("invalid ciphertext buffer length")
|
||||||
|
|
||||||
|
|
||||||
|
class AuthKeyMismatchError(ValueError):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__("server authkey mismatches with ours")
|
||||||
|
|
||||||
|
|
||||||
|
class MsgKeyMismatchError(ValueError):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__("server msgkey mismatches with ours")
|
||||||
|
|
||||||
|
|
||||||
|
CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError
|
||||||
|
|
||||||
|
|
||||||
# "where x = 0 for messages from client to server and x = 8 for those from server to client"
|
# "where x = 0 for messages from client to server and x = 8 for those from server to client"
|
||||||
class Side(IntEnum):
|
class Side(IntEnum):
|
||||||
CLIENT = 0
|
CLIENT = 0
|
||||||
@ -77,14 +95,14 @@ def decrypt_data_v2(
|
|||||||
x = int(side)
|
x = int(side)
|
||||||
|
|
||||||
if len(ciphertext) < 24 or (len(ciphertext) - 24) % 16 != 0:
|
if len(ciphertext) < 24 or (len(ciphertext) - 24) % 16 != 0:
|
||||||
raise ValueError("invalid ciphertext buffer length")
|
raise InvalidBufferError()
|
||||||
|
|
||||||
# salt, session_id and sequence_number should also be checked.
|
# salt, session_id and sequence_number should also be checked.
|
||||||
# However, not doing so has worked fine for years.
|
# However, not doing so has worked fine for years.
|
||||||
|
|
||||||
key_id = ciphertext[:8]
|
key_id = ciphertext[:8]
|
||||||
if auth_key.key_id != key_id:
|
if auth_key.key_id != key_id:
|
||||||
raise ValueError("server authkey mismatches with ours")
|
raise AuthKeyMismatchError()
|
||||||
|
|
||||||
msg_key = ciphertext[8:24]
|
msg_key = ciphertext[8:24]
|
||||||
key, iv = calc_key(auth_key, msg_key, side)
|
key, iv = calc_key(auth_key, msg_key, side)
|
||||||
@ -93,7 +111,7 @@ def decrypt_data_v2(
|
|||||||
# https://core.telegram.org/mtproto/security_guidelines#mtproto-encrypted-messages
|
# https://core.telegram.org/mtproto/security_guidelines#mtproto-encrypted-messages
|
||||||
our_key = sha256(auth_key.data[x + 88 : x + 120] + plaintext).digest()
|
our_key = sha256(auth_key.data[x + 88 : x + 120] + plaintext).digest()
|
||||||
if msg_key != our_key[8:24]:
|
if msg_key != our_key[8:24]:
|
||||||
raise ValueError("server msgkey mismatches with ours")
|
raise MsgKeyMismatchError()
|
||||||
|
|
||||||
return plaintext
|
return plaintext
|
||||||
|
|
||||||
|
@ -62,11 +62,15 @@ from ..utils import (
|
|||||||
)
|
)
|
||||||
from .types import (
|
from .types import (
|
||||||
BadMessageError,
|
BadMessageError,
|
||||||
|
DecompressionFailed,
|
||||||
Deserialization,
|
Deserialization,
|
||||||
|
DeserializationFailure,
|
||||||
|
MsgBufferTooSmall,
|
||||||
MsgId,
|
MsgId,
|
||||||
Mtp,
|
Mtp,
|
||||||
RpcError,
|
RpcError,
|
||||||
RpcResult,
|
RpcResult,
|
||||||
|
UnexpectedConstructor,
|
||||||
Update,
|
Update,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -85,6 +89,7 @@ UPDATE_IDS = {
|
|||||||
AffectedFoundMessages.constructor_id(),
|
AffectedFoundMessages.constructor_id(),
|
||||||
AffectedHistory.constructor_id(),
|
AffectedHistory.constructor_id(),
|
||||||
AffectedMessages.constructor_id(),
|
AffectedMessages.constructor_id(),
|
||||||
|
# TODO InvitedUsers
|
||||||
}
|
}
|
||||||
|
|
||||||
HEADER_LEN = 8 + 8 # salt, client_id
|
HEADER_LEN = 8 + 8 # salt, client_id
|
||||||
@ -151,7 +156,7 @@ class Encrypted(Mtp):
|
|||||||
self._last_msg_id: int
|
self._last_msg_id: int
|
||||||
self._in_pending_ack: list[int] = []
|
self._in_pending_ack: list[int] = []
|
||||||
self._msg_count: int
|
self._msg_count: int
|
||||||
self._reset_session()
|
self.reset()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def auth_key(self) -> bytes:
|
def auth_key(self) -> bytes:
|
||||||
@ -166,13 +171,6 @@ class Encrypted(Mtp):
|
|||||||
def _adjusted_now(self) -> float:
|
def _adjusted_now(self) -> float:
|
||||||
return time.time() + self._time_offset
|
return time.time() + self._time_offset
|
||||||
|
|
||||||
def _reset_session(self) -> None:
|
|
||||||
self._client_id = struct.unpack("<q", os.urandom(8))[0]
|
|
||||||
self._sequence = 0
|
|
||||||
self._last_msg_id = 0
|
|
||||||
self._in_pending_ack.clear()
|
|
||||||
self._msg_count = 0
|
|
||||||
|
|
||||||
def _get_new_msg_id(self) -> int:
|
def _get_new_msg_id(self) -> int:
|
||||||
new_msg_id = int(self._adjusted_now() * 0x100000000)
|
new_msg_id = int(self._adjusted_now() * 0x100000000)
|
||||||
if self._last_msg_id >= new_msg_id:
|
if self._last_msg_id >= new_msg_id:
|
||||||
@ -245,12 +243,38 @@ class Encrypted(Mtp):
|
|||||||
result = rpc_result.result
|
result = rpc_result.result
|
||||||
|
|
||||||
msg_id = MsgId(req_msg_id)
|
msg_id = MsgId(req_msg_id)
|
||||||
inner_constructor = struct.unpack_from("<I", result)[0]
|
|
||||||
|
try:
|
||||||
|
inner_constructor = struct.unpack_from("<I", result)[0]
|
||||||
|
except struct.error:
|
||||||
|
# If the result is empty, we can't unpack it.
|
||||||
|
# This can happen if the server returns an empty response.
|
||||||
|
logging.exception("failed to unpack inner_constructor")
|
||||||
|
self._deserialization.append(
|
||||||
|
DeserializationFailure(
|
||||||
|
msg_id=msg_id,
|
||||||
|
error=MsgBufferTooSmall(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if inner_constructor == GeneratedRpcError.constructor_id():
|
if inner_constructor == GeneratedRpcError.constructor_id():
|
||||||
error = RpcError._from_mtproto_error(GeneratedRpcError.from_bytes(result))
|
try:
|
||||||
error.msg_id = msg_id
|
error = RpcError._from_mtproto_error(
|
||||||
self._deserialization.append(error)
|
GeneratedRpcError.from_bytes(result)
|
||||||
|
)
|
||||||
|
error.msg_id = msg_id
|
||||||
|
self._deserialization.append(error)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("failed to deserialize error")
|
||||||
|
self._deserialization.append(
|
||||||
|
DeserializationFailure(
|
||||||
|
msg_id=msg_id,
|
||||||
|
error=UnexpectedConstructor(
|
||||||
|
id=inner_constructor,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
elif inner_constructor == RpcAnswerUnknown.constructor_id():
|
elif inner_constructor == RpcAnswerUnknown.constructor_id():
|
||||||
pass # msg_id = rpc_drop_answer.msg_id
|
pass # msg_id = rpc_drop_answer.msg_id
|
||||||
elif inner_constructor == RpcAnswerDroppedRunning.constructor_id():
|
elif inner_constructor == RpcAnswerDroppedRunning.constructor_id():
|
||||||
@ -258,9 +282,15 @@ class Encrypted(Mtp):
|
|||||||
elif inner_constructor == RpcAnswerDropped.constructor_id():
|
elif inner_constructor == RpcAnswerDropped.constructor_id():
|
||||||
pass # dropped
|
pass # dropped
|
||||||
elif inner_constructor == GzipPacked.constructor_id():
|
elif inner_constructor == GzipPacked.constructor_id():
|
||||||
body = gzip_decompress(GzipPacked.from_bytes(result))
|
try:
|
||||||
self._store_own_updates(body)
|
body = gzip_decompress(GzipPacked.from_bytes(result))
|
||||||
self._deserialization.append(RpcResult(msg_id, body))
|
self._store_own_updates(body)
|
||||||
|
self._deserialization.append(RpcResult(msg_id, body))
|
||||||
|
except Exception:
|
||||||
|
logging.exception("failed to decompress response")
|
||||||
|
self._deserialization.append(
|
||||||
|
DeserializationFailure(msg_id=msg_id, error=DecompressionFailed())
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self._store_own_updates(result)
|
self._store_own_updates(result)
|
||||||
self._deserialization.append(RpcResult(msg_id, result))
|
self._deserialization.append(RpcResult(msg_id, result))
|
||||||
@ -300,7 +330,7 @@ class Encrypted(Mtp):
|
|||||||
elif bad_msg.error_code in (16, 17):
|
elif bad_msg.error_code in (16, 17):
|
||||||
self._correct_time_offset(message.msg_id)
|
self._correct_time_offset(message.msg_id)
|
||||||
elif bad_msg.error_code in (32, 33):
|
elif bad_msg.error_code in (32, 33):
|
||||||
self._reset_session()
|
self.reset()
|
||||||
else:
|
else:
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
@ -365,6 +395,9 @@ class Encrypted(Mtp):
|
|||||||
for inner_message in container.messages:
|
for inner_message in container.messages:
|
||||||
self._process_message(inner_message)
|
self._process_message(inner_message)
|
||||||
|
|
||||||
|
def _handle_msg_copy(self, message: Message) -> None:
|
||||||
|
raise RuntimeError("msg_copy should not be used")
|
||||||
|
|
||||||
def _handle_gzip_packed(self, message: Message) -> None:
|
def _handle_gzip_packed(self, message: Message) -> None:
|
||||||
container = GzipPacked.from_bytes(message.body)
|
container = GzipPacked.from_bytes(message.body)
|
||||||
inner_body = gzip_decompress(container)
|
inner_body = gzip_decompress(container)
|
||||||
@ -459,3 +492,11 @@ class Encrypted(Mtp):
|
|||||||
result = self._deserialization[:]
|
result = self._deserialization[:]
|
||||||
self._deserialization.clear()
|
self._deserialization.clear()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._client_id = struct.unpack("<q", os.urandom(8))[0]
|
||||||
|
self._sequence = 0
|
||||||
|
self._last_msg_id = 0
|
||||||
|
self._in_pending_ack.clear()
|
||||||
|
self._msg_count = 0
|
||||||
|
self._salt_request_msg_id = None
|
||||||
|
@ -2,7 +2,16 @@ import struct
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..utils import check_message_buffer
|
from ..utils import check_message_buffer
|
||||||
from .types import Deserialization, MsgId, Mtp, RpcResult
|
from .types import (
|
||||||
|
BadAuthKeyError,
|
||||||
|
BadMsgIdError,
|
||||||
|
Deserialization,
|
||||||
|
MsgId,
|
||||||
|
Mtp,
|
||||||
|
NegativeLengthError,
|
||||||
|
RpcResult,
|
||||||
|
TooLongMsgError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Plain(Mtp):
|
class Plain(Mtp):
|
||||||
@ -38,18 +47,19 @@ class Plain(Mtp):
|
|||||||
|
|
||||||
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
|
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
|
||||||
if auth_key_id != 0:
|
if auth_key_id != 0:
|
||||||
raise ValueError(f"bad auth key, expected: 0, got: {auth_key_id}")
|
raise BadAuthKeyError(got=auth_key_id, expected=0)
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/description#message-identifier-msg-id
|
# https://core.telegram.org/mtproto/description#message-identifier-msg-id
|
||||||
if msg_id <= 0 or (msg_id % 4) != 1:
|
if msg_id <= 0 or (msg_id % 4) != 1:
|
||||||
raise ValueError(f"bad msg id, got: {msg_id}")
|
raise BadMsgIdError(got=msg_id)
|
||||||
|
|
||||||
if length < 0:
|
if length < 0:
|
||||||
raise ValueError(f"bad length: expected >= 0, got: {length}")
|
raise NegativeLengthError(got=length)
|
||||||
|
|
||||||
if 20 + length > len(payload):
|
if 20 + length > (lp := len(payload)):
|
||||||
raise ValueError(
|
raise TooLongMsgError(got=length, max_length=lp - 20)
|
||||||
f"message too short, expected: {20 + length}, got {len(payload)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))]
|
return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))]
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._buffer.clear()
|
||||||
|
@ -5,6 +5,7 @@ from typing import NewType, Optional
|
|||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from ...crypto.crypto import CryptoError
|
||||||
from ...tl.mtproto.types import RpcError as GeneratedRpcError
|
from ...tl.mtproto.types import RpcError as GeneratedRpcError
|
||||||
|
|
||||||
MsgId = NewType("MsgId", int)
|
MsgId = NewType("MsgId", int)
|
||||||
@ -180,7 +181,112 @@ class BadMessageError(ValueError):
|
|||||||
return self._code == other._code
|
return self._code == other._code
|
||||||
|
|
||||||
|
|
||||||
Deserialization = Update | RpcResult | RpcError | BadMessageError
|
# Deserialization errors are not fatal, so we don't subclass RpcError.
|
||||||
|
class BadAuthKeyError(ValueError):
|
||||||
|
def __init__(self, *args: object, got: int, expected: int) -> None:
|
||||||
|
super().__init__(f"bad server auth key (got {got}, expected {expected})", *args)
|
||||||
|
self._got = got
|
||||||
|
self._expected = expected
|
||||||
|
|
||||||
|
@property
|
||||||
|
def got(self):
|
||||||
|
return self._got
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected(self):
|
||||||
|
return self._expected
|
||||||
|
|
||||||
|
|
||||||
|
class BadMsgIdError(ValueError):
|
||||||
|
def __init__(self, *args: object, got: int) -> None:
|
||||||
|
super().__init__(f"bad server message id (got {got})", *args)
|
||||||
|
self._got = got
|
||||||
|
|
||||||
|
@property
|
||||||
|
def got(self):
|
||||||
|
return self._got
|
||||||
|
|
||||||
|
|
||||||
|
class NegativeLengthError(ValueError):
|
||||||
|
def __init__(self, *args: object, got: int) -> None:
|
||||||
|
super().__init__(f"bad server message length (got {got})", *args)
|
||||||
|
self._got = got
|
||||||
|
|
||||||
|
@property
|
||||||
|
def got(self):
|
||||||
|
return self._got
|
||||||
|
|
||||||
|
|
||||||
|
class TooLongMsgError(ValueError):
|
||||||
|
def __init__(self, *args: object, got: int, max_length: int) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"bad server message length (got {got}, when at most it should be {max_length})",
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
self._got = got
|
||||||
|
self._expected = max_length
|
||||||
|
|
||||||
|
@property
|
||||||
|
def got(self):
|
||||||
|
return self._got
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected(self):
|
||||||
|
return self._expected
|
||||||
|
|
||||||
|
|
||||||
|
class MsgBufferTooSmall(ValueError):
|
||||||
|
def __init__(self, *args: object) -> None:
|
||||||
|
super().__init__(
|
||||||
|
"server responded with a payload that's too small to fit a valid message",
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DecompressionFailed(ValueError):
|
||||||
|
def __init__(self, *args: object) -> None:
|
||||||
|
super().__init__("failed to decompress server's data", *args)
|
||||||
|
|
||||||
|
|
||||||
|
class UnexpectedConstructor(ValueError):
|
||||||
|
def __init__(self, *args: object, id: int) -> None:
|
||||||
|
super().__init__(f"unexpected constructor: {id:08x}", *args)
|
||||||
|
|
||||||
|
|
||||||
|
class DecryptionError(ValueError):
|
||||||
|
def __init__(self, *args: object, error: CryptoError) -> None:
|
||||||
|
super().__init__(f"failed to decrypt message: {error}", *args)
|
||||||
|
|
||||||
|
self._error = error
|
||||||
|
|
||||||
|
@property
|
||||||
|
def error(self):
|
||||||
|
return self._error
|
||||||
|
|
||||||
|
|
||||||
|
DeserializationError = (
|
||||||
|
BadAuthKeyError
|
||||||
|
| BadMsgIdError
|
||||||
|
| NegativeLengthError
|
||||||
|
| TooLongMsgError
|
||||||
|
| MsgBufferTooSmall
|
||||||
|
| DecompressionFailed
|
||||||
|
| UnexpectedConstructor
|
||||||
|
| DecryptionError
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeserializationFailure:
|
||||||
|
__slots__ = ("msg_id", "error")
|
||||||
|
|
||||||
|
def __init__(self, msg_id: MsgId, error: DeserializationError) -> None:
|
||||||
|
self.msg_id = msg_id
|
||||||
|
self.error = error
|
||||||
|
|
||||||
|
|
||||||
|
Deserialization = (
|
||||||
|
Update | RpcResult | RpcError | BadMessageError | DeserializationFailure
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/description
|
# https://core.telegram.org/mtproto/description
|
||||||
@ -209,3 +315,9 @@ class Mtp(ABC):
|
|||||||
"""
|
"""
|
||||||
Deserialize incoming buffer payload.
|
Deserialize incoming buffer payload.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""
|
||||||
|
Reset the internal buffer.
|
||||||
|
"""
|
||||||
|
@ -15,13 +15,37 @@ class Transport(ABC):
|
|||||||
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
|
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MissingBytesError(ValueError):
|
class MissingBytesError(ValueError):
|
||||||
def __init__(self, *, expected: int, got: int) -> None:
|
def __init__(self, *, expected: int, got: int) -> None:
|
||||||
super().__init__(f"missing bytes, expected: {expected}, got: {got}")
|
super().__init__(f"missing bytes, expected: {expected}, got: {got}")
|
||||||
|
|
||||||
|
|
||||||
|
class BadLenError(ValueError):
|
||||||
|
def __init__(self, *, got: int) -> None:
|
||||||
|
super().__init__(f"bad len (got {got})")
|
||||||
|
|
||||||
|
|
||||||
|
class BadSeqError(ValueError):
|
||||||
|
def __init__(self, *, expected: int, got: int) -> None:
|
||||||
|
super().__init__(f"bad seq (expected {expected}, got {got})")
|
||||||
|
|
||||||
|
|
||||||
|
class BadCrcError(ValueError):
|
||||||
|
def __init__(self, *, expected: int, got: int) -> None:
|
||||||
|
super().__init__(f"bad crc (expected {expected}, got {got})")
|
||||||
|
|
||||||
|
|
||||||
class BadStatusError(ValueError):
|
class BadStatusError(ValueError):
|
||||||
def __init__(self, *, status: int) -> None:
|
def __init__(self, *, status: int) -> None:
|
||||||
super().__init__(f"transport reported bad status: {status}")
|
super().__init__(f"bad status (negative length -{status})")
|
||||||
self.status = status
|
self.status = status
|
||||||
|
|
||||||
|
|
||||||
|
TransportError = (
|
||||||
|
MissingBytesError | BadLenError | BadSeqError | BadCrcError | BadStatusError
|
||||||
|
)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
|
from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
|
||||||
@ -60,3 +61,7 @@ class Abridged(Transport):
|
|||||||
|
|
||||||
output += memoryview(input)[header_len : header_len + length]
|
output += memoryview(input)[header_len : header_len + length]
|
||||||
return header_len + length
|
return header_len + length
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
logging.info("resetting sending of header in abridged transport")
|
||||||
|
self._init = False
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from zlib import crc32
|
from zlib import crc32
|
||||||
|
|
||||||
@ -61,3 +62,8 @@ class Full(Transport):
|
|||||||
self._recv_seq += 1
|
self._recv_seq += 1
|
||||||
output += memoryview(input)[8 : length - 4]
|
output += memoryview(input)[8 : length - 4]
|
||||||
return length
|
return length
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
logging.info("resetting recv and send seqs in full transport")
|
||||||
|
self._send_seq = 0
|
||||||
|
self._recv_seq = 0
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
|
from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
|
||||||
@ -52,3 +53,7 @@ class Intermediate(Transport):
|
|||||||
|
|
||||||
output += memoryview(input)[4 : 4 + length]
|
output += memoryview(input)[4 : 4 + length]
|
||||||
return length + 4
|
return length + 4
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
logging.info("resetting sending of header in intermediate transport")
|
||||||
|
self._init = False
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from .reconnection import ReconnectionPolicy
|
||||||
from .sender import (
|
from .sender import (
|
||||||
MAXIMUM_DATA,
|
MAXIMUM_DATA,
|
||||||
NO_PING_DISCONNECT,
|
NO_PING_DISCONNECT,
|
||||||
@ -18,4 +19,5 @@ __all__ = [
|
|||||||
"Connector",
|
"Connector",
|
||||||
"Sender",
|
"Sender",
|
||||||
"connect",
|
"connect",
|
||||||
|
"ReconnectionPolicy",
|
||||||
]
|
]
|
||||||
|
6
client/src/telethon/_impl/mtsender/errors.py
Normal file
6
client/src/telethon/_impl/mtsender/errors.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from struct import error as struct_error
|
||||||
|
|
||||||
|
from ..mtproto.mtp.types import DeserializationError
|
||||||
|
from ..mtproto.transport.abcs import TransportError
|
||||||
|
|
||||||
|
ReadError = struct_error | TransportError | DeserializationError
|
36
client/src/telethon/_impl/mtsender/reconnection.py
Normal file
36
client/src/telethon/_impl/mtsender/reconnection.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class ReconnectionPolicy(ABC):
|
||||||
|
"""
|
||||||
|
Base class for reconnection policies.
|
||||||
|
|
||||||
|
This class defines the interface for reconnection policies used by the MTSender.
|
||||||
|
It allows for custom reconnection strategies to be implemented by subclasses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def should_retry(self, attempts: int) -> bool | float:
|
||||||
|
"""
|
||||||
|
Determines whether the client should retry the connection attempt.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NoReconnect(ReconnectionPolicy):
|
||||||
|
def should_retry(self, attempts: int) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class FixedReconnect(ReconnectionPolicy):
|
||||||
|
__slots__ = ("max_attempts", "delay")
|
||||||
|
|
||||||
|
def __init__(self, attempts: int, delay: float):
|
||||||
|
self.max_attempts = attempts
|
||||||
|
self.delay = delay
|
||||||
|
|
||||||
|
def should_retry(self, attempts: int) -> bool | float:
|
||||||
|
if attempts < self.max_attempts:
|
||||||
|
return self.delay
|
||||||
|
|
||||||
|
return False
|
@ -24,12 +24,14 @@ from ..mtproto import (
|
|||||||
Update,
|
Update,
|
||||||
authentication,
|
authentication,
|
||||||
)
|
)
|
||||||
|
from ..mtproto.mtp.types import DeserializationFailure
|
||||||
from ..tl import Request as RemoteCall
|
from ..tl import Request as RemoteCall
|
||||||
from ..tl.abcs import Updates
|
from ..tl.abcs import Updates
|
||||||
from ..tl.core import Serializable
|
from ..tl.core import Serializable
|
||||||
from ..tl.mtproto.functions import ping_delay_disconnect
|
from ..tl.mtproto.functions import ping_delay_disconnect
|
||||||
from ..tl.types import UpdateDeleteMessages, UpdateShort
|
from ..tl.types import UpdateDeleteMessages, UpdateShort, UpdatesTooLong
|
||||||
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
|
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
|
||||||
|
from .reconnection import ReconnectionPolicy
|
||||||
|
|
||||||
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
|
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
|
||||||
|
|
||||||
@ -162,6 +164,8 @@ class Request(Generic[Return]):
|
|||||||
class Sender:
|
class Sender:
|
||||||
dc_id: int
|
dc_id: int
|
||||||
addr: str
|
addr: str
|
||||||
|
_connector: Connector
|
||||||
|
_reconnection_policy: ReconnectionPolicy
|
||||||
_logger: logging.Logger
|
_logger: logging.Logger
|
||||||
_reader: AsyncReader
|
_reader: AsyncReader
|
||||||
_writer: AsyncWriter
|
_writer: AsyncWriter
|
||||||
@ -175,6 +179,7 @@ class Sender:
|
|||||||
_requests: list[Request[object]]
|
_requests: list[Request[object]]
|
||||||
_next_ping: float
|
_next_ping: float
|
||||||
_read_buffer: bytearray
|
_read_buffer: bytearray
|
||||||
|
_write_drain_pending: bool
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def connect(
|
async def connect(
|
||||||
@ -185,6 +190,7 @@ class Sender:
|
|||||||
addr: str,
|
addr: str,
|
||||||
*,
|
*,
|
||||||
connector: Connector,
|
connector: Connector,
|
||||||
|
reconnection_policy: ReconnectionPolicy,
|
||||||
base_logger: logging.Logger,
|
base_logger: logging.Logger,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
ip, port = addr.split(":")
|
ip, port = addr.split(":")
|
||||||
@ -193,6 +199,8 @@ class Sender:
|
|||||||
return cls(
|
return cls(
|
||||||
dc_id=dc_id,
|
dc_id=dc_id,
|
||||||
addr=addr,
|
addr=addr,
|
||||||
|
_connector=connector,
|
||||||
|
_reconnection_policy=reconnection_policy,
|
||||||
_logger=base_logger.getChild("mtsender"),
|
_logger=base_logger.getChild("mtsender"),
|
||||||
_reader=reader,
|
_reader=reader,
|
||||||
_writer=writer,
|
_writer=writer,
|
||||||
@ -206,6 +214,7 @@ class Sender:
|
|||||||
_requests=[],
|
_requests=[],
|
||||||
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
|
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
|
||||||
_read_buffer=bytearray(),
|
_read_buffer=bytearray(),
|
||||||
|
_write_drain_pending=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def disconnect(self) -> None:
|
async def disconnect(self) -> None:
|
||||||
@ -236,15 +245,21 @@ class Sender:
|
|||||||
if rx.done():
|
if rx.done():
|
||||||
return rx.result()
|
return rx.result()
|
||||||
|
|
||||||
async def step(self) -> None:
|
async def step(self):
|
||||||
|
try:
|
||||||
|
await self._step()
|
||||||
|
except Exception as error:
|
||||||
|
await self._on_error(error)
|
||||||
|
|
||||||
|
async def _step(self) -> None:
|
||||||
if not self._writing:
|
if not self._writing:
|
||||||
self._writing = True
|
self._writing = True
|
||||||
await self._do_write()
|
await self._do_send()
|
||||||
self._writing = False
|
self._writing = False
|
||||||
|
|
||||||
if not self._reading:
|
if not self._reading:
|
||||||
self._reading = True
|
self._reading = True
|
||||||
await self._do_read()
|
await self._do_recv()
|
||||||
self._reading = False
|
self._reading = False
|
||||||
else:
|
else:
|
||||||
await self._step_done.wait()
|
await self._step_done.wait()
|
||||||
@ -254,7 +269,7 @@ class Sender:
|
|||||||
self._updates.clear()
|
self._updates.clear()
|
||||||
return updates
|
return updates
|
||||||
|
|
||||||
async def _do_read(self) -> None:
|
async def _do_recv(self) -> None:
|
||||||
self._step_done.clear()
|
self._step_done.clear()
|
||||||
|
|
||||||
timeout = self._next_ping - asyncio.get_running_loop().time()
|
timeout = self._next_ping - asyncio.get_running_loop().time()
|
||||||
@ -266,10 +281,46 @@ class Sender:
|
|||||||
else:
|
else:
|
||||||
self._on_net_read(recv_data)
|
self._on_net_read(recv_data)
|
||||||
finally:
|
finally:
|
||||||
self._try_timeout_ping()
|
self._try_ping_timeout()
|
||||||
self._step_done.set()
|
self._step_done.set()
|
||||||
|
|
||||||
async def _do_write(self) -> None:
|
async def _do_send(self) -> None:
|
||||||
|
self._try_fill_write()
|
||||||
|
|
||||||
|
if self._write_drain_pending:
|
||||||
|
await self._writer.drain()
|
||||||
|
self._on_net_write()
|
||||||
|
|
||||||
|
async def _try_connect(self):
|
||||||
|
attempts = 0
|
||||||
|
|
||||||
|
ip, port = self.addr.split(":")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
self._reader, self._writer = await self._connector(ip, int(port))
|
||||||
|
self._logger.info(
|
||||||
|
f"auto-reconnect success after {attempts} failed attempt(s)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
attempts += 1
|
||||||
|
self._logger.warning(f"auto-reconnect failed {attempts} time(s): {e!r}")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
delay = self._reconnection_policy.should_retry(attempts)
|
||||||
|
|
||||||
|
if delay:
|
||||||
|
if delay is not True:
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
self._logger.error(
|
||||||
|
f"auto-reconnect failed {attempts} time(s); giving up"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _try_fill_write(self) -> None:
|
||||||
if not self._requests:
|
if not self._requests:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -283,15 +334,14 @@ class Sender:
|
|||||||
result = self._mtp.finalize()
|
result = self._mtp.finalize()
|
||||||
if result:
|
if result:
|
||||||
container_msg_id, mtp_buffer = result
|
container_msg_id, mtp_buffer = result
|
||||||
|
|
||||||
self._transport.pack(mtp_buffer, self._writer.write)
|
|
||||||
await self._writer.drain()
|
|
||||||
|
|
||||||
for request in self._requests:
|
for request in self._requests:
|
||||||
if isinstance(request.state, Serialized):
|
if isinstance(request.state, Serialized):
|
||||||
request.state = Sent(request.state.msg_id, container_msg_id)
|
request.state.container_msg_id = container_msg_id
|
||||||
|
|
||||||
def _try_timeout_ping(self) -> None:
|
self._transport.pack(mtp_buffer, self._writer.write)
|
||||||
|
self._write_drain_pending = True
|
||||||
|
|
||||||
|
def _try_ping_timeout(self) -> None:
|
||||||
current_time = asyncio.get_running_loop().time()
|
current_time = asyncio.get_running_loop().time()
|
||||||
|
|
||||||
if current_time >= self._next_ping:
|
if current_time >= self._next_ping:
|
||||||
@ -321,6 +371,45 @@ class Sender:
|
|||||||
del self._read_buffer[:n]
|
del self._read_buffer[:n]
|
||||||
self._process_mtp_buffer()
|
self._process_mtp_buffer()
|
||||||
|
|
||||||
|
def _on_net_write(self) -> None:
|
||||||
|
for req in self._requests:
|
||||||
|
if isinstance(req.state, Serialized):
|
||||||
|
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
|
||||||
|
|
||||||
|
async def _on_error(self, error: Exception) -> None:
|
||||||
|
self._logger.info(f"handling error: {error}")
|
||||||
|
self._transport.reset()
|
||||||
|
self._mtp.reset()
|
||||||
|
self._logger.info(
|
||||||
|
"resetting sender state from read_buffer {}, mtp_buffer {}".format(
|
||||||
|
len(self._read_buffer),
|
||||||
|
len(self._mtp_buffer),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._read_buffer.clear()
|
||||||
|
self._mtp_buffer.clear()
|
||||||
|
|
||||||
|
if isinstance(error, struct.error) and self._reconnection_policy.should_retry(
|
||||||
|
0
|
||||||
|
):
|
||||||
|
self._logger.info(f"read error occurred: {error}")
|
||||||
|
await self._try_connect()
|
||||||
|
|
||||||
|
for req in self._requests:
|
||||||
|
req.state = NotSerialized()
|
||||||
|
|
||||||
|
self._updates.append(UpdatesTooLong())
|
||||||
|
return
|
||||||
|
|
||||||
|
self._logger.warning(
|
||||||
|
f"marking all {len(self._requests)} request(s) as failed: {error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for req in self._requests:
|
||||||
|
req.result.set_exception(error)
|
||||||
|
|
||||||
|
raise error
|
||||||
|
|
||||||
def _process_mtp_buffer(self) -> None:
|
def _process_mtp_buffer(self) -> None:
|
||||||
results = self._mtp.deserialize(self._mtp_buffer)
|
results = self._mtp.deserialize(self._mtp_buffer)
|
||||||
|
|
||||||
@ -331,8 +420,14 @@ class Sender:
|
|||||||
self._process_result(result)
|
self._process_result(result)
|
||||||
elif isinstance(result, RpcError):
|
elif isinstance(result, RpcError):
|
||||||
self._process_error(result)
|
self._process_error(result)
|
||||||
else:
|
elif isinstance(result, BadMessageError):
|
||||||
self._process_bad_message(result)
|
self._process_bad_message(result)
|
||||||
|
elif isinstance(result, DeserializationFailure):
|
||||||
|
self._process_deserialize_error(result)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"unexpected result type {type(result).__name__}: {result}"
|
||||||
|
)
|
||||||
|
|
||||||
def _process_update(self, update: bytes | bytearray | memoryview) -> None:
|
def _process_update(self, update: bytes | bytearray | memoryview) -> None:
|
||||||
try:
|
try:
|
||||||
@ -421,6 +516,17 @@ class Sender:
|
|||||||
result._caused_by = struct.unpack_from("<I", req.body)[0]
|
result._caused_by = struct.unpack_from("<I", req.body)[0]
|
||||||
req.result.set_exception(result)
|
req.result.set_exception(result)
|
||||||
|
|
||||||
|
def _process_deserialize_error(self, failure: DeserializationFailure):
|
||||||
|
req = self._pop_request(failure.msg_id)
|
||||||
|
|
||||||
|
if req:
|
||||||
|
self._logger.debug(f"got deserialization failure {failure.error}")
|
||||||
|
req.result.set_exception(failure.error)
|
||||||
|
else:
|
||||||
|
self._logger.info(
|
||||||
|
f"got deserialization failure {failure.error} but no such request is saved"
|
||||||
|
)
|
||||||
|
|
||||||
def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]:
|
def _pop_request(self, msg_id: MsgId) -> Optional[Request[object]]:
|
||||||
for i, req in enumerate(self._requests):
|
for i, req in enumerate(self._requests):
|
||||||
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
|
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
|
||||||
@ -459,6 +565,7 @@ async def connect(
|
|||||||
auth_key: Optional[bytes],
|
auth_key: Optional[bytes],
|
||||||
base_logger: logging.Logger,
|
base_logger: logging.Logger,
|
||||||
connector: Connector,
|
connector: Connector,
|
||||||
|
reconnection_policy: ReconnectionPolicy,
|
||||||
) -> Sender:
|
) -> Sender:
|
||||||
if auth_key is None:
|
if auth_key is None:
|
||||||
sender = await Sender.connect(
|
sender = await Sender.connect(
|
||||||
@ -467,6 +574,7 @@ async def connect(
|
|||||||
dc_id,
|
dc_id,
|
||||||
addr,
|
addr,
|
||||||
connector=connector,
|
connector=connector,
|
||||||
|
reconnection_policy=reconnection_policy,
|
||||||
base_logger=base_logger,
|
base_logger=base_logger,
|
||||||
)
|
)
|
||||||
return await generate_auth_key(sender)
|
return await generate_auth_key(sender)
|
||||||
@ -477,6 +585,7 @@ async def connect(
|
|||||||
dc_id,
|
dc_id,
|
||||||
addr,
|
addr,
|
||||||
connector=connector,
|
connector=connector,
|
||||||
|
reconnection_policy=reconnection_policy,
|
||||||
base_logger=base_logger,
|
base_logger=base_logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user