mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-06-18 19:16:43 +00:00
Return serialized container MsgId on finalize
This commit is contained in:
parent
6fd3eb2ee6
commit
c7d1a36969
@ -204,9 +204,9 @@ class Encrypted(Mtp):
|
|||||||
def _get_current_salt(self) -> int:
|
def _get_current_salt(self) -> int:
|
||||||
return self._salts[-1].salt if self._salts else 0
|
return self._salts[-1].salt if self._salts else 0
|
||||||
|
|
||||||
def _finalize_plain(self) -> bytes:
|
def _finalize_plain(self) -> Optional[Tuple[MsgId, bytes]]:
|
||||||
if not self._msg_count:
|
if not self._msg_count:
|
||||||
return b""
|
return None
|
||||||
|
|
||||||
if self._msg_count == 1:
|
if self._msg_count == 1:
|
||||||
del self._buffer[:CONTAINER_HEADER_LEN]
|
del self._buffer[:CONTAINER_HEADER_LEN]
|
||||||
@ -216,7 +216,7 @@ class Encrypted(Mtp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self._msg_count == 1:
|
if self._msg_count == 1:
|
||||||
container_msg_id: Union[Type[Single], int] = Single
|
container_msg_id = self._last_msg_id
|
||||||
else:
|
else:
|
||||||
container_msg_id = self._get_new_msg_id()
|
container_msg_id = self._get_new_msg_id()
|
||||||
self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
|
self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
|
||||||
@ -235,7 +235,7 @@ class Encrypted(Mtp):
|
|||||||
self._msg_count = 0
|
self._msg_count = 0
|
||||||
result = bytes(self._buffer)
|
result = bytes(self._buffer)
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
return result
|
return MsgId(container_msg_id), result
|
||||||
|
|
||||||
def _process_message(self, message: Message) -> None:
|
def _process_message(self, message: Message) -> None:
|
||||||
if message_requires_ack(message):
|
if message_requires_ack(message):
|
||||||
@ -465,12 +465,13 @@ class Encrypted(Mtp):
|
|||||||
|
|
||||||
return self._serialize_msg(body, True)
|
return self._serialize_msg(body, True)
|
||||||
|
|
||||||
def finalize(self) -> bytes:
|
def finalize(self) -> Optional[Tuple[MsgId, bytes]]:
|
||||||
buffer = self._finalize_plain()
|
result = self._finalize_plain()
|
||||||
if not buffer:
|
if not result:
|
||||||
return buffer
|
return None
|
||||||
else:
|
|
||||||
return encrypt_data_v2(buffer, self._auth_key)
|
msg_id, buffer = result
|
||||||
|
return msg_id, encrypt_data_v2(buffer, self._auth_key)
|
||||||
|
|
||||||
def deserialize(self, payload: bytes) -> Deserialization:
|
def deserialize(self, payload: bytes) -> Deserialization:
|
||||||
check_message_buffer(payload)
|
check_message_buffer(payload)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import struct
|
import struct
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from ..utils import check_message_buffer
|
from ..utils import check_message_buffer
|
||||||
from .types import Deserialization, MsgId, Mtp
|
from .types import Deserialization, MsgId, Mtp
|
||||||
@ -23,10 +23,13 @@ class Plain(Mtp):
|
|||||||
self._buffer += request # message_data
|
self._buffer += request # message_data
|
||||||
return msg_id
|
return msg_id
|
||||||
|
|
||||||
def finalize(self) -> bytes:
|
def finalize(self) -> Optional[Tuple[MsgId, bytes]]:
|
||||||
|
if not self._buffer:
|
||||||
|
return None
|
||||||
|
|
||||||
result = bytes(self._buffer)
|
result = bytes(self._buffer)
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
return result
|
return MsgId(0), result
|
||||||
|
|
||||||
def deserialize(self, payload: bytes) -> Deserialization:
|
def deserialize(self, payload: bytes) -> Deserialization:
|
||||||
check_message_buffer(payload)
|
check_message_buffer(payload)
|
||||||
|
@ -165,12 +165,23 @@ class Deserialization:
|
|||||||
class Mtp(ABC):
|
class Mtp(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def push(self, request: bytes) -> Optional[MsgId]:
|
def push(self, request: bytes) -> Optional[MsgId]:
|
||||||
pass
|
"""
|
||||||
|
Push a request's body to the internal buffer.
|
||||||
|
|
||||||
|
On success, return the serialized message identifier.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def finalize(self) -> bytes:
|
def finalize(self) -> Optional[Tuple[MsgId, bytes]]:
|
||||||
pass
|
"""
|
||||||
|
Finalize the buffer of serialized requests.
|
||||||
|
|
||||||
|
If the buffer is empty, :data:`None` is returned.
|
||||||
|
Otherwise, the message identifier for the entire buffer and the serialized buffer are returned.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def deserialize(self, payload: bytes) -> Deserialization:
|
def deserialize(self, payload: bytes) -> Deserialization:
|
||||||
pass
|
"""
|
||||||
|
Deserialize incoming buffer payload.
|
||||||
|
"""
|
||||||
|
@ -271,8 +271,9 @@ class Sender:
|
|||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
mtp_buffer = self._mtp.finalize()
|
result = self._mtp.finalize()
|
||||||
if mtp_buffer:
|
if result:
|
||||||
|
_, mtp_buffer = result
|
||||||
self._transport.pack(mtp_buffer, self._writer.write)
|
self._transport.pack(mtp_buffer, self._writer.write)
|
||||||
self._write_drain_pending = True
|
self._write_drain_pending = True
|
||||||
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
import struct
|
import struct
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from pytest import raises
|
from pytest import raises
|
||||||
from telethon._impl.crypto import AuthKey
|
from telethon._impl.crypto import AuthKey
|
||||||
from telethon._impl.mtproto import Encrypted, Plain, RpcError
|
from telethon._impl.mtproto import Encrypted, Plain, RpcError
|
||||||
|
from telethon._impl.mtproto.mtp.types import MsgId
|
||||||
from telethon._impl.tl.mtproto.types import RpcError as GeneratedRpcError
|
from telethon._impl.tl.mtproto.types import RpcError as GeneratedRpcError
|
||||||
|
|
||||||
|
|
||||||
@ -47,14 +49,20 @@ def test_rpc_error_parsing() -> None:
|
|||||||
PLAIN_REQUEST = b"Hey!"
|
PLAIN_REQUEST = b"Hey!"
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap_finalize(finalized: Optional[Tuple[MsgId, bytes]]) -> bytes:
|
||||||
|
assert finalized is not None
|
||||||
|
_, buffer = finalized
|
||||||
|
return buffer
|
||||||
|
|
||||||
|
|
||||||
def test_plain_finalize_clears_buffer() -> None:
|
def test_plain_finalize_clears_buffer() -> None:
|
||||||
mtp = Plain()
|
mtp = Plain()
|
||||||
|
|
||||||
mtp.push(PLAIN_REQUEST)
|
mtp.push(PLAIN_REQUEST)
|
||||||
assert len(mtp.finalize()) == 24
|
assert len(unwrap_finalize(mtp.finalize())) == 24
|
||||||
|
|
||||||
mtp.push(PLAIN_REQUEST)
|
mtp.push(PLAIN_REQUEST)
|
||||||
assert len(mtp.finalize()) == 24
|
assert len(unwrap_finalize(mtp.finalize())) == 24
|
||||||
|
|
||||||
|
|
||||||
def test_plain_only_one_push_allowed() -> None:
|
def test_plain_only_one_push_allowed() -> None:
|
||||||
@ -90,7 +98,7 @@ def test_serialization_has_salt_client_id() -> None:
|
|||||||
mtp = Encrypted(auth_key())
|
mtp = Encrypted(auth_key())
|
||||||
|
|
||||||
mtp.push(REQUEST)
|
mtp.push(REQUEST)
|
||||||
buffer = mtp._finalize_plain()
|
buffer = unwrap_finalize(mtp._finalize_plain())
|
||||||
|
|
||||||
# salt
|
# salt
|
||||||
assert buffer[0:8] == bytes(8)
|
assert buffer[0:8] == bytes(8)
|
||||||
@ -104,7 +112,7 @@ def test_correct_single_serialization() -> None:
|
|||||||
mtp = Encrypted(auth_key())
|
mtp = Encrypted(auth_key())
|
||||||
|
|
||||||
assert mtp.push(REQUEST) is not None
|
assert mtp.push(REQUEST) is not None
|
||||||
buffer = mtp._finalize_plain()
|
buffer = unwrap_finalize(mtp._finalize_plain())
|
||||||
|
|
||||||
ensure_buffer_is_message(buffer[MESSAGE_PREFIX_LEN:], REQUEST, 1)
|
ensure_buffer_is_message(buffer[MESSAGE_PREFIX_LEN:], REQUEST, 1)
|
||||||
|
|
||||||
@ -114,7 +122,7 @@ def test_correct_multi_serialization() -> None:
|
|||||||
|
|
||||||
assert mtp.push(REQUEST) is not None
|
assert mtp.push(REQUEST) is not None
|
||||||
assert mtp.push(REQUEST_B) is not None
|
assert mtp.push(REQUEST_B) is not None
|
||||||
buffer = mtp._finalize_plain()
|
buffer = unwrap_finalize(mtp._finalize_plain())
|
||||||
buffer = buffer[MESSAGE_PREFIX_LEN:]
|
buffer = buffer[MESSAGE_PREFIX_LEN:]
|
||||||
|
|
||||||
# container msg_id
|
# container msg_id
|
||||||
@ -138,7 +146,7 @@ def test_correct_single_large_serialization() -> None:
|
|||||||
data = bytes(0x7F for _ in range(768 * 1024))
|
data = bytes(0x7F for _ in range(768 * 1024))
|
||||||
|
|
||||||
assert mtp.push(data) is not None
|
assert mtp.push(data) is not None
|
||||||
buffer = mtp._finalize_plain()
|
buffer = unwrap_finalize(mtp._finalize_plain())
|
||||||
|
|
||||||
buffer = buffer[MESSAGE_PREFIX_LEN:]
|
buffer = buffer[MESSAGE_PREFIX_LEN:]
|
||||||
assert len(buffer) == 16 + len(data)
|
assert len(buffer) == 16 + len(data)
|
||||||
@ -151,7 +159,7 @@ def test_correct_multi_large_serialization() -> None:
|
|||||||
assert mtp.push(data) is not None
|
assert mtp.push(data) is not None
|
||||||
assert mtp.push(data) is None
|
assert mtp.push(data) is None
|
||||||
|
|
||||||
buffer = mtp._finalize_plain()
|
buffer = unwrap_finalize(mtp._finalize_plain())
|
||||||
buffer = buffer[MESSAGE_PREFIX_LEN:]
|
buffer = buffer[MESSAGE_PREFIX_LEN:]
|
||||||
assert len(buffer) == 16 + len(data)
|
assert len(buffer) == 16 + len(data)
|
||||||
|
|
||||||
@ -173,22 +181,22 @@ def test_non_padded_payload_panics() -> None:
|
|||||||
def test_no_compression_is_honored() -> None:
|
def test_no_compression_is_honored() -> None:
|
||||||
mtp = Encrypted(auth_key(), compression_threshold=None)
|
mtp = Encrypted(auth_key(), compression_threshold=None)
|
||||||
mtp.push(bytes(512 * 1024))
|
mtp.push(bytes(512 * 1024))
|
||||||
buffer = mtp._finalize_plain()
|
buffer = unwrap_finalize(mtp._finalize_plain())
|
||||||
assert GZIP_PACKED_HEADER not in buffer
|
assert GZIP_PACKED_HEADER not in buffer
|
||||||
|
|
||||||
|
|
||||||
def test_some_compression() -> None:
|
def test_some_compression() -> None:
|
||||||
mtp = Encrypted(auth_key(), compression_threshold=768 * 1024)
|
mtp = Encrypted(auth_key(), compression_threshold=768 * 1024)
|
||||||
mtp.push(bytes(512 * 1024))
|
mtp.push(bytes(512 * 1024))
|
||||||
buffer = mtp._finalize_plain()
|
buffer = unwrap_finalize(mtp._finalize_plain())
|
||||||
assert GZIP_PACKED_HEADER not in buffer
|
assert GZIP_PACKED_HEADER not in buffer
|
||||||
|
|
||||||
mtp = Encrypted(auth_key(), compression_threshold=256 * 1024)
|
mtp = Encrypted(auth_key(), compression_threshold=256 * 1024)
|
||||||
mtp.push(bytes(512 * 1024))
|
mtp.push(bytes(512 * 1024))
|
||||||
buffer = mtp._finalize_plain()
|
buffer = unwrap_finalize(mtp._finalize_plain())
|
||||||
assert GZIP_PACKED_HEADER in buffer
|
assert GZIP_PACKED_HEADER in buffer
|
||||||
|
|
||||||
mtp = Encrypted(auth_key())
|
mtp = Encrypted(auth_key())
|
||||||
mtp.push(bytes(512 * 1024))
|
mtp.push(bytes(512 * 1024))
|
||||||
buffer = mtp._finalize_plain()
|
buffer = unwrap_finalize(mtp._finalize_plain())
|
||||||
assert GZIP_PACKED_HEADER in buffer
|
assert GZIP_PACKED_HEADER in buffer
|
||||||
|
Loading…
Reference in New Issue
Block a user