Make TLMessage always have a valid TLObject

This simplifies the flow instead of having separate request/body
attributes, and also means that BinaryReader.tgread_object() can
be used without so many special cases.
This commit is contained in:
Lonami Exo 2018-06-09 13:48:27 +02:00
parent f7e8907c6f
commit be279ce3f5
6 changed files with 98 additions and 91 deletions

View File

@ -12,14 +12,15 @@ class TypeNotFoundError(Exception):
Occurs when a type is not found, for example, Occurs when a type is not found, for example,
when trying to read a TLObject with an invalid constructor code. when trying to read a TLObject with an invalid constructor code.
""" """
def __init__(self, invalid_constructor_id): def __init__(self, invalid_constructor_id, remaining):
super().__init__( super().__init__(
'Could not find a matching Constructor ID for the TLObject ' 'Could not find a matching Constructor ID for the TLObject '
'that was supposed to be read with ID {}. Most likely, a TLObject ' 'that was supposed to be read with ID {:08x}. Most likely, '
'was trying to be read when it should not be read.' 'a TLObject was trying to be read when it should not be read. '
.format(hex(invalid_constructor_id))) 'Remaining bytes: {!r}'.format(invalid_constructor_id, remaining))
self.invalid_constructor_id = invalid_constructor_id self.invalid_constructor_id = invalid_constructor_id
self.remaining = remaining
class InvalidChecksumError(Exception): class InvalidChecksumError(Exception):

View File

@ -141,7 +141,10 @@ class BinaryReader:
if clazz is None: if clazz is None:
# If there was still no luck, give up # If there was still no luck, give up
self.seek(-4) # Go back self.seek(-4) # Go back
raise TypeNotFoundError(constructor_id) pos = self.tell_position()
error = TypeNotFoundError(constructor_id, self.read())
self.set_position(pos)
raise error
return clazz.from_reader(self) return clazz.from_reader(self)

View File

@ -351,29 +351,27 @@ class MTProtoSender:
__log__.warning('Security error while unpacking a ' __log__.warning('Security error while unpacking a '
'received message:'.format(e)) 'received message:'.format(e))
continue continue
except TypeNotFoundError as e:
# The payload inside the message was not a known TLObject.
__log__.info('Server replied with an unknown type {:08x}: {!r}'
.format(e.invalid_constructor_id, e.remaining))
else: else:
try: await self._process_message(message)
with BinaryReader(message.body) as reader:
obj = reader.tgread_object()
except TypeNotFoundError as e:
__log__.warning('Could not decode received message: {}, '
'raw bytes: {!r}'.format(e, message))
else:
await self._process_message(message, obj)
# Response Handlers # Response Handlers
async def _process_message(self, message, obj): async def _process_message(self, message):
""" """
Adds the given message to the list of messages that must be Adds the given message to the list of messages that must be
acknowledged and dispatches control to different ``_handle_*`` acknowledged and dispatches control to different ``_handle_*``
method based on its type. method based on its type.
""" """
self._pending_ack.add(message.msg_id) self._pending_ack.add(message.msg_id)
handler = self._handlers.get(obj.CONSTRUCTOR_ID, self._handle_update) handler = self._handlers.get(message.obj.CONSTRUCTOR_ID,
await handler(message, obj) self._handle_update)
await handler(message)
async def _handle_rpc_result(self, message, rpc_result): async def _handle_rpc_result(self, message):
""" """
Handles the result for Remote Procedure Calls: Handles the result for Remote Procedure Calls:
@ -381,6 +379,7 @@ class MTProtoSender:
This is where the future results for sent requests are set. This is where the future results for sent requests are set.
""" """
rpc_result = message.obj
message = self._pending_messages.pop(rpc_result.req_msg_id, None) message = self._pending_messages.pop(rpc_result.req_msg_id, None)
__log__.debug('Handling RPC result for message {}' __log__.debug('Handling RPC result for message {}'
.format(rpc_result.req_msg_id)) .format(rpc_result.req_msg_id))
@ -397,7 +396,7 @@ class MTProtoSender:
return return
elif message: elif message:
with BinaryReader(rpc_result.body) as reader: with BinaryReader(rpc_result.body) as reader:
result = message.request.read_result(reader) result = message.obj.read_result(reader)
# TODO Process entities # TODO Process entities
if not message.future.cancelled(): if not message.future.cancelled():
@ -408,35 +407,35 @@ class MTProtoSender:
__log__.info('Received response without parent request: {}' __log__.info('Received response without parent request: {}'
.format(rpc_result.body)) .format(rpc_result.body))
async def _handle_container(self, message, container): async def _handle_container(self, message):
""" """
Processes the inner messages of a container with many of them: Processes the inner messages of a container with many of them:
msg_container#73f1f8dc messages:vector<%Message> = MessageContainer; msg_container#73f1f8dc messages:vector<%Message> = MessageContainer;
""" """
__log__.debug('Handling container') __log__.debug('Handling container')
for inner_message in container.messages: for inner_message in message.obj.messages:
with BinaryReader(inner_message.body) as reader: await self._process_message(inner_message)
inner_obj = reader.tgread_object()
await self._process_message(inner_message, inner_obj)
async def _handle_gzip_packed(self, message, gzip_packed): async def _handle_gzip_packed(self, message):
""" """
Unpacks the data from a gzipped object and processes it: Unpacks the data from a gzipped object and processes it:
gzip_packed#3072cfa1 packed_data:bytes = Object; gzip_packed#3072cfa1 packed_data:bytes = Object;
""" """
__log__.debug('Handling gzipped data') __log__.debug('Handling gzipped data')
with BinaryReader(gzip_packed.data) as reader: with BinaryReader(message.obj.data) as reader:
await self._process_message(message, reader.tgread_object()) message.obj = reader.tgread_object()
await self._process_message(message)
async def _handle_update(self, message, update): async def _handle_update(self, message):
__log__.debug('Handling update {}'.format(update.__class__.__name__)) __log__.debug('Handling update {}'
.format(message.obj.__class__.__name__))
# TODO Further handling of the update # TODO Further handling of the update
# TODO Process entities # TODO Process entities
async def _handle_pong(self, message, pong): async def _handle_pong(self, message):
""" """
Handles pong results, which don't come inside a ``rpc_result`` Handles pong results, which don't come inside a ``rpc_result``
but are still sent through a request: but are still sent through a request:
@ -444,11 +443,12 @@ class MTProtoSender:
pong#347773c5 msg_id:long ping_id:long = Pong; pong#347773c5 msg_id:long ping_id:long = Pong;
""" """
__log__.debug('Handling pong') __log__.debug('Handling pong')
pong = message.obj
message = self._pending_messages.pop(pong.msg_id, None) message = self._pending_messages.pop(pong.msg_id, None)
if message: if message:
message.future.set_result(pong) message.future.set_result(pong.obj)
async def _handle_bad_server_salt(self, message, bad_salt): async def _handle_bad_server_salt(self, message):
""" """
Corrects the currently used server salt to use the right value Corrects the currently used server salt to use the right value
before enqueuing the rejected message to be re-sent: before enqueuing the rejected message to be re-sent:
@ -457,10 +457,11 @@ class MTProtoSender:
error_code:int new_server_salt:long = BadMsgNotification; error_code:int new_server_salt:long = BadMsgNotification;
""" """
__log__.debug('Handling bad salt') __log__.debug('Handling bad salt')
bad_salt = message.obj
self.state.salt = bad_salt.new_server_salt self.state.salt = bad_salt.new_server_salt
await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id]) await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id])
async def _handle_bad_notification(self, message, bad_msg): async def _handle_bad_notification(self, message):
""" """
Adjusts the current state to be correct based on the Adjusts the current state to be correct based on the
received bad message notification whenever possible: received bad message notification whenever possible:
@ -469,6 +470,7 @@ class MTProtoSender:
error_code:int = BadMsgNotification; error_code:int = BadMsgNotification;
""" """
__log__.debug('Handling bad message') __log__.debug('Handling bad message')
bad_msg = message.obj
if bad_msg.error_code in (16, 17): if bad_msg.error_code in (16, 17):
# Sent msg_id too low or too high (respectively). # Sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset. # Use the current msg_id to determine the right time offset.
@ -489,7 +491,7 @@ class MTProtoSender:
# Messages are to be re-sent once we've corrected the issue # Messages are to be re-sent once we've corrected the issue
await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id]) await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id])
async def _handle_detailed_info(self, message, detailed_info): async def _handle_detailed_info(self, message):
""" """
Updates the current status with the received detailed information: Updates the current status with the received detailed information:
@ -498,9 +500,9 @@ class MTProtoSender:
""" """
# TODO https://goo.gl/VvpCC6 # TODO https://goo.gl/VvpCC6
__log__.debug('Handling detailed info') __log__.debug('Handling detailed info')
self._pending_ack.add(detailed_info.answer_msg_id) self._pending_ack.add(message.obj.answer_msg_id)
async def _handle_new_detailed_info(self, message, new_detailed_info): async def _handle_new_detailed_info(self, message):
""" """
Updates the current status with the received detailed information: Updates the current status with the received detailed information:
@ -509,9 +511,9 @@ class MTProtoSender:
""" """
# TODO https://goo.gl/G7DPsR # TODO https://goo.gl/G7DPsR
__log__.debug('Handling new detailed info') __log__.debug('Handling new detailed info')
self._pending_ack.add(new_detailed_info.answer_msg_id) self._pending_ack.add(message.obj.answer_msg_id)
async def _handle_new_session_created(self, message, new_session): async def _handle_new_session_created(self, message):
""" """
Updates the current status with the received session information: Updates the current status with the received session information:
@ -520,7 +522,7 @@ class MTProtoSender:
""" """
# TODO https://goo.gl/LMyN7A # TODO https://goo.gl/LMyN7A
__log__.debug('Handling new session created') __log__.debug('Handling new session created')
self.state.salt = new_session.server_salt self.state.salt = message.obj.server_salt
def _clean_containers(self, msg_ids): def _clean_containers(self, msg_ids):
""" """
@ -533,13 +535,13 @@ class MTProtoSender:
""" """
for i in reversed(range(len(self._pending_containers))): for i in reversed(range(len(self._pending_containers))):
message = self._pending_containers[i] message = self._pending_containers[i]
for msg in message.request.messages: for msg in message.obj.messages:
if msg.msg_id in msg_ids: if msg.msg_id in msg_ids:
del self._pending_containers[i] del self._pending_containers[i]
del self._pending_messages[message.msg_id] del self._pending_messages[message.msg_id]
break break
async def _handle_ack(self, message, ack): async def _handle_ack(self, message):
""" """
Handles a server acknowledge about our messages. Normally Handles a server acknowledge about our messages. Normally
these can be ignored except in the case of ``auth.logOut``: these can be ignored except in the case of ``auth.logOut``:
@ -555,16 +557,17 @@ class MTProtoSender:
messages are acknowledged. messages are acknowledged.
""" """
__log__.debug('Handling acknowledge') __log__.debug('Handling acknowledge')
ack = message.obj
if self._pending_containers: if self._pending_containers:
self._clean_containers(ack.msg_ids) self._clean_containers(ack.msg_ids)
for msg_id in ack.msg_ids: for msg_id in ack.msg_ids:
msg = self._pending_messages.get(msg_id, None) msg = self._pending_messages.get(msg_id, None)
if msg and isinstance(msg.request, LogOutRequest): if msg and isinstance(msg.obj, LogOutRequest):
del self._pending_messages[msg_id] del self._pending_messages[msg_id]
msg.future.set_result(True) msg.future.set_result(True)
async def _handle_future_salts(self, message, salts): async def _handle_future_salts(self, message):
""" """
Handles future salt results, which don't come inside a Handles future salt results, which don't come inside a
``rpc_result`` but are still sent through a request: ``rpc_result`` but are still sent through a request:
@ -577,7 +580,7 @@ class MTProtoSender:
__log__.debug('Handling future salts') __log__.debug('Handling future salts')
msg = self._pending_messages.pop(message.msg_id, None) msg = self._pending_messages.pop(message.msg_id, None)
if msg: if msg:
msg.future.set_result(salts) msg.future.set_result(message.obj)
class _ContainerQueue(asyncio.Queue): class _ContainerQueue(asyncio.Queue):
@ -593,13 +596,13 @@ class _ContainerQueue(asyncio.Queue):
""" """
async def get(self): async def get(self):
result = await super().get() result = await super().get()
if self.empty() or isinstance(result.request, MessageContainer): if self.empty() or isinstance(result.obj, MessageContainer):
return result return result
result = [result] result = [result]
while not self.empty(): while not self.empty():
item = self.get_nowait() item = self.get_nowait()
if isinstance(item.request, MessageContainer): if isinstance(item.obj, MessageContainer):
await self.put(item) await self.put(item)
break break
else: else:

View File

@ -1,3 +1,4 @@
import logging
import os import os
import struct import struct
import time import time
@ -8,6 +9,8 @@ from ..errors import SecurityError, BrokenAuthKeyError
from ..extensions import BinaryReader from ..extensions import BinaryReader
from ..tl.core import TLMessage from ..tl.core import TLMessage
__log__ = logging.getLogger(__name__)
class MTProtoState: class MTProtoState:
""" """
@ -33,15 +36,15 @@ class MTProtoState:
self._sequence = 0 self._sequence = 0
self._last_msg_id = 0 self._last_msg_id = 0
def create_message(self, request, after=None): def create_message(self, obj, after=None):
""" """
Creates a new `telethon.tl.tl_message.TLMessage` from Creates a new `telethon.tl.tl_message.TLMessage` from
the given `telethon.tl.tlobject.TLObject` instance. the given `telethon.tl.tlobject.TLObject` instance.
""" """
return TLMessage( return TLMessage(
msg_id=self._get_new_msg_id(), msg_id=self._get_new_msg_id(),
seq_no=self._get_seq_no(request.content_related), seq_no=self._get_seq_no(obj.content_related),
request=request, obj=obj,
after_id=after.msg_id if after else None after_id=after.msg_id if after else None
) )
@ -100,25 +103,31 @@ class MTProtoState:
msg_key = body[8:24] msg_key = body[8:24]
aes_key, aes_iv = self._calc_key(self.auth_key.key, msg_key, False) aes_key, aes_iv = self._calc_key(self.auth_key.key, msg_key, False)
data = BinaryReader(AES.decrypt_ige(body[24:], aes_key, aes_iv)) body = AES.decrypt_ige(body[24:], aes_key, aes_iv)
data.read_long() # remote_salt
if data.read_long() != self.id:
raise SecurityError('Server replied with a wrong session ID')
remote_msg_id = data.read_long()
remote_sequence = data.read_int()
msg_len = data.read_int()
message = data.read(msg_len)
# https://core.telegram.org/mtproto/security_guidelines # https://core.telegram.org/mtproto/security_guidelines
# Sections "checking sha256 hash" and "message length" # Sections "checking sha256 hash" and "message length"
our_key = sha256(self.auth_key.key[96:96 + 32] + data.get_bytes()) our_key = sha256(self.auth_key.key[96:96 + 32] + body)
if msg_key != our_key.digest()[8:24]: if msg_key != our_key.digest()[8:24]:
raise SecurityError( raise SecurityError(
"Received msg_key doesn't match with expected one") "Received msg_key doesn't match with expected one")
return TLMessage(remote_msg_id, remote_sequence, body=message) reader = BinaryReader(body)
reader.read_long() # remote_salt
if reader.read_long() != self.id:
raise SecurityError('Server replied with a wrong session ID')
remote_msg_id = reader.read_long()
remote_sequence = reader.read_int()
msg_len = reader.read_int()
before = reader.tell_position()
obj = reader.tgread_object()
if reader.tell_position() != before + msg_len:
reader.set_position(before)
__log__.warning('Data left after TLObject {}: {!r}'
.format(obj, reader.read(msg_len)))
return TLMessage(remote_msg_id, remote_sequence, obj)
def _get_new_msg_id(self): def _get_new_msg_id(self):
""" """

View File

@ -1,7 +1,10 @@
import logging
import struct import struct
from ..tlobject import TLObject
from .tlmessage import TLMessage from .tlmessage import TLMessage
from ..tlobject import TLObject
__log__ = logging.getLogger(__name__)
class MessageContainer(TLObject): class MessageContainer(TLObject):
@ -26,17 +29,6 @@ class MessageContainer(TLObject):
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages) '<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
) + b''.join(bytes(m) for m in self.messages) ) + b''.join(bytes(m) for m in self.messages)
@staticmethod
def iter_read(reader):
reader.read_int(signed=False) # code
size = reader.read_int()
for _ in range(size):
inner_msg_id = reader.read_long()
inner_sequence = reader.read_int()
inner_length = reader.read_int()
yield TLMessage(inner_msg_id, inner_sequence,
body=reader.read(inner_length))
def __str__(self): def __str__(self):
return TLObject.pretty_format(self) return TLObject.pretty_format(self)
@ -46,8 +38,16 @@ class MessageContainer(TLObject):
@classmethod @classmethod
def from_reader(cls, reader): def from_reader(cls, reader):
# This assumes that .read_* calls are done in the order they appear # This assumes that .read_* calls are done in the order they appear
return MessageContainer([TLMessage( messages = []
msg_id=reader.read_long(), for _ in range(reader.read_int()):
seq_no=reader.read_int(), msg_id = reader.read_long()
body=reader.read(reader.read_int()) seq_no = reader.read_int()
) for _ in range(reader.read_int())]) length = reader.read_int()
before = reader.tell_position()
obj = reader.tgread_object()
messages.append(TLMessage(msg_id, seq_no, obj))
if reader.tell_position() != before + length:
reader.set_position(before)
__log__.warning('Data left after TLObject {}: {!r}'
.format(obj, reader.read(length)))
return MessageContainer(messages)

View File

@ -21,23 +21,14 @@ class TLMessage(TLObject):
sent `TLMessage`, and this result can be represented as a `Future` sent `TLMessage`, and this result can be represented as a `Future`
that will eventually be set with either a result, error or cancelled. that will eventually be set with either a result, error or cancelled.
""" """
def __init__(self, msg_id, seq_no, body=None, request=None, after_id=0): def __init__(self, msg_id, seq_no, obj=None, after_id=0):
super().__init__() super().__init__()
self.msg_id = msg_id self.msg_id = msg_id
self.seq_no = seq_no self.seq_no = seq_no
self.obj = obj
self.container_msg_id = None self.container_msg_id = None
self.future = asyncio.Future() self.future = asyncio.Future()
# TODO Perhaps it's possible to merge body and request?
# We need things like rpc_result and gzip_packed to
# be readable by the ``BinaryReader`` for such purpose.
# Used for incoming, not-decoded messages
self.body = body
# Used for outgoing, not-encoded messages
self.request = request
# After which message ID this one should run. We do this so # After which message ID this one should run. We do this so
# InvokeAfterMsgRequest is transparent to the user and we can # InvokeAfterMsgRequest is transparent to the user and we can
# easily invoke after while confirming the original request. # easily invoke after while confirming the original request.
@ -47,17 +38,17 @@ class TLMessage(TLObject):
return { return {
'msg_id': self.msg_id, 'msg_id': self.msg_id,
'seq_no': self.seq_no, 'seq_no': self.seq_no,
'request': self.request, 'obj': self.obj,
'container_msg_id': self.container_msg_id, 'container_msg_id': self.container_msg_id,
'after_id': self.after_id 'after_id': self.after_id
} }
def __bytes__(self): def __bytes__(self):
if self.after_id is None: if self.after_id is None:
body = GzipPacked.gzip_if_smaller(self.request) body = GzipPacked.gzip_if_smaller(self.obj)
else: else:
body = GzipPacked.gzip_if_smaller( body = GzipPacked.gzip_if_smaller(
InvokeAfterMsgRequest(self.after_id, self.request)) InvokeAfterMsgRequest(self.after_id, self.obj))
return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body