Create RpcResult class and generalise core special cases

This results in a cleaner MTProtoSender, which now can always
read a TLObject with a guaranteed item, if the message is OK.
This commit is contained in:
Lonami Exo 2018-06-09 13:11:49 +02:00
parent 1e66cea9b7
commit f7e8907c6f
11 changed files with 132 additions and 84 deletions

View File

@ -40,49 +40,49 @@ def report_error(code, message, report_method):
"We really don't want to crash when just reporting an error" "We really don't want to crash when just reporting an error"
def rpc_message_to_error(code, message, report_method=None): def rpc_message_to_error(rpc_error, report_method=None):
""" """
Converts a Telegram's RPC Error to a Python error. Converts a Telegram's RPC Error to a Python error.
:param code: the integer code of the error (like 400). :param rpc_error: the RpcError instance.
:param message: the message representing the error.
:param report_method: if present, the ID of the method that caused it. :param report_method: if present, the ID of the method that caused it.
:return: the RPCError as a Python exception that represents this error. :return: the RPCError as a Python exception that represents this error.
""" """
if report_method is not None: if report_method is not None:
Thread( Thread(
target=report_error, target=report_error,
args=(code, message, report_method) args=(rpc_error.error_code, rpc_error.error_message, report_method)
).start() ).start()
# Try to get the error by direct look-up, otherwise regex # Try to get the error by direct look-up, otherwise regex
# TODO Maybe regexes could live in a separate dictionary? # TODO Maybe regexes could live in a separate dictionary?
cls = rpc_errors_all.get(message, None) cls = rpc_errors_all.get(rpc_error.error_message, None)
if cls: if cls:
return cls() return cls()
for msg_regex, cls in rpc_errors_all.items(): for msg_regex, cls in rpc_errors_all.items():
m = re.match(msg_regex, message) m = re.match(msg_regex, rpc_error.error_message)
if m: if m:
capture = int(m.group(1)) if m.groups() else None capture = int(m.group(1)) if m.groups() else None
return cls(capture=capture) return cls(capture=capture)
if code == 400: if rpc_error.error_code == 400:
return BadRequestError(message) return BadRequestError(rpc_error.error_message)
if code == 401: if rpc_error.error_code == 401:
return UnauthorizedError(message) return UnauthorizedError(rpc_error.error_message)
if code == 403: if rpc_error.error_code == 403:
return ForbiddenError(message) return ForbiddenError(rpc_error.error_message)
if code == 404: if rpc_error.error_code == 404:
return NotFoundError(message) return NotFoundError(rpc_error.error_message)
if code == 406: if rpc_error.error_code == 406:
return AuthKeyError(message) return AuthKeyError(rpc_error.error_message)
if code == 500: if rpc_error.error_code == 500:
return ServerError(message) return ServerError(rpc_error.error_message)
return RPCError('{} (code {})'.format(message, code)) return RPCError('{} (code {})'.format(
rpc_error.error_message, rpc_error.error_code))

View File

@ -8,6 +8,7 @@ from struct import unpack
from ..errors import TypeNotFoundError from ..errors import TypeNotFoundError
from ..tl.all_tlobjects import tlobjects from ..tl.all_tlobjects import tlobjects
from ..tl.core import core_objects
class BinaryReader: class BinaryReader:
@ -136,9 +137,11 @@ class BinaryReader:
elif value == 0x1cb5c415: # Vector elif value == 0x1cb5c415: # Vector
return [self.tgread_object() for _ in range(self.read_int())] return [self.tgread_object() for _ in range(self.read_int())]
# If there was still no luck, give up clazz = core_objects.get(constructor_id, None)
self.seek(-4) # Go back if clazz is None:
raise TypeNotFoundError(constructor_id) # If there was still no luck, give up
self.seek(-4) # Go back
raise TypeNotFoundError(constructor_id)
return clazz.from_reader(self) return clazz.from_reader(self)

View File

@ -9,7 +9,7 @@ from ..errors import (
rpc_message_to_error rpc_message_to_error
) )
from ..extensions import BinaryReader from ..extensions import BinaryReader
from ..tl import MessageContainer, GzipPacked from ..tl.core import RpcResult, MessageContainer, GzipPacked
from ..tl.functions.auth import LogOutRequest from ..tl.functions.auth import LogOutRequest
from ..tl.types import ( from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts, MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
@ -80,7 +80,7 @@ class MTProtoSender:
# Jump table from response ID to method that handles it # Jump table from response ID to method that handles it
self._handlers = { self._handlers = {
0xf35c6d01: self._handle_rpc_result, RpcResult.CONSTRUCTOR_ID: self._handle_rpc_result,
MessageContainer.CONSTRUCTOR_ID: self._handle_container, MessageContainer.CONSTRUCTOR_ID: self._handle_container,
GzipPacked.CONSTRUCTOR_ID: self._handle_gzip_packed, GzipPacked.CONSTRUCTOR_ID: self._handle_gzip_packed,
Pong.CONSTRUCTOR_ID: self._handle_pong, Pong.CONSTRUCTOR_ID: self._handle_pong,
@ -354,26 +354,26 @@ class MTProtoSender:
else: else:
try: try:
with BinaryReader(message.body) as reader: with BinaryReader(message.body) as reader:
await self._process_message(message, reader) obj = reader.tgread_object()
except TypeNotFoundError as e: except TypeNotFoundError as e:
__log__.warning('Could not decode received message: {}, ' __log__.warning('Could not decode received message: {}, '
'raw bytes: {!r}'.format(e, message)) 'raw bytes: {!r}'.format(e, message))
else:
await self._process_message(message, obj)
# Response Handlers # Response Handlers
async def _process_message(self, message, reader): async def _process_message(self, message, obj):
""" """
Adds the given message to the list of messages that must be Adds the given message to the list of messages that must be
acknowledged and dispatches control to different ``_handle_*`` acknowledged and dispatches control to different ``_handle_*``
method based on its type. method based on its type.
""" """
self._pending_ack.add(message.msg_id) self._pending_ack.add(message.msg_id)
code = reader.read_int(signed=False) handler = self._handlers.get(obj.CONSTRUCTOR_ID, self._handle_update)
reader.seek(-4) await handler(message, obj)
handler = self._handlers.get(code, self._handle_update)
await handler(message, reader)
async def _handle_rpc_result(self, message, reader): async def _handle_rpc_result(self, message, rpc_result):
""" """
Handles the result for Remote Procedure Calls: Handles the result for Remote Procedure Calls:
@ -381,20 +381,13 @@ class MTProtoSender:
This is where the future results for sent requests are set. This is where the future results for sent requests are set.
""" """
# TODO Don't make this a special cased object message = self._pending_messages.pop(rpc_result.req_msg_id, None)
reader.read_int(signed=False) # code __log__.debug('Handling RPC result for message {}'
message_id = reader.read_long() .format(rpc_result.req_msg_id))
inner_code = reader.read_int(signed=False)
reader.seek(-4)
__log__.debug('Handling RPC result for message {}'.format(message_id)) if rpc_result.error:
message = self._pending_messages.pop(message_id, None)
if inner_code == 0x2144ca19: # RPC Error
# TODO Report errors if possible/enabled # TODO Report errors if possible/enabled
reader.seek(4) error = rpc_message_to_error(rpc_result.error)
error = rpc_message_to_error(reader.read_int(),
reader.tgread_string())
await self._send_queue.put(self.state.create_message( await self._send_queue.put(self.state.create_message(
MsgsAck([message.msg_id]) MsgsAck([message.msg_id])
)) ))
@ -403,10 +396,7 @@ class MTProtoSender:
message.future.set_exception(error) message.future.set_exception(error)
return return
elif message: elif message:
if inner_code == GzipPacked.CONSTRUCTOR_ID: with BinaryReader(rpc_result.body) as reader:
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
result = message.request.read_result(compressed_reader)
else:
result = message.request.read_result(reader) result = message.request.read_result(reader)
# TODO Process entities # TODO Process entities
@ -416,37 +406,37 @@ class MTProtoSender:
else: else:
# TODO We should not get responses to things we never sent # TODO We should not get responses to things we never sent
__log__.info('Received response without parent request: {}' __log__.info('Received response without parent request: {}'
.format(reader.tgread_object())) .format(rpc_result.body))
async def _handle_container(self, message, reader): async def _handle_container(self, message, container):
""" """
Processes the inner messages of a container with many of them: Processes the inner messages of a container with many of them:
msg_container#73f1f8dc messages:vector<%Message> = MessageContainer; msg_container#73f1f8dc messages:vector<%Message> = MessageContainer;
""" """
__log__.debug('Handling container') __log__.debug('Handling container')
for inner_message in MessageContainer.iter_read(reader): for inner_message in container.messages:
with BinaryReader(inner_message.body) as inner_reader: with BinaryReader(inner_message.body) as reader:
await self._process_message(inner_message, inner_reader) inner_obj = reader.tgread_object()
await self._process_message(inner_message, inner_obj)
async def _handle_gzip_packed(self, message, reader): async def _handle_gzip_packed(self, message, gzip_packed):
""" """
Unpacks the data from a gzipped object and processes it: Unpacks the data from a gzipped object and processes it:
gzip_packed#3072cfa1 packed_data:bytes = Object; gzip_packed#3072cfa1 packed_data:bytes = Object;
""" """
__log__.debug('Handling gzipped data') __log__.debug('Handling gzipped data')
with BinaryReader(GzipPacked.read(reader)) as compressed_reader: with BinaryReader(gzip_packed.data) as reader:
await self._process_message(message, compressed_reader) await self._process_message(message, reader.tgread_object())
async def _handle_update(self, message, reader): async def _handle_update(self, message, update):
obj = reader.tgread_object() __log__.debug('Handling update {}'.format(update.__class__.__name__))
__log__.debug('Handling update {}'.format(obj.__class__.__name__))
# TODO Further handling of the update # TODO Further handling of the update
# TODO Process entities # TODO Process entities
async def _handle_pong(self, message, reader): async def _handle_pong(self, message, pong):
""" """
Handles pong results, which don't come inside a ``rpc_result`` Handles pong results, which don't come inside a ``rpc_result``
but are still sent through a request: but are still sent through a request:
@ -454,12 +444,11 @@ class MTProtoSender:
pong#347773c5 msg_id:long ping_id:long = Pong; pong#347773c5 msg_id:long ping_id:long = Pong;
""" """
__log__.debug('Handling pong') __log__.debug('Handling pong')
pong = reader.tgread_object()
message = self._pending_messages.pop(pong.msg_id, None) message = self._pending_messages.pop(pong.msg_id, None)
if message: if message:
message.future.set_result(pong) message.future.set_result(pong)
async def _handle_bad_server_salt(self, message, reader): async def _handle_bad_server_salt(self, message, bad_salt):
""" """
Corrects the currently used server salt to use the right value Corrects the currently used server salt to use the right value
before enqueuing the rejected message to be re-sent: before enqueuing the rejected message to be re-sent:
@ -468,11 +457,10 @@ class MTProtoSender:
error_code:int new_server_salt:long = BadMsgNotification; error_code:int new_server_salt:long = BadMsgNotification;
""" """
__log__.debug('Handling bad salt') __log__.debug('Handling bad salt')
bad_salt = reader.tgread_object()
self.state.salt = bad_salt.new_server_salt self.state.salt = bad_salt.new_server_salt
await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id]) await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id])
async def _handle_bad_notification(self, message, reader): async def _handle_bad_notification(self, message, bad_msg):
""" """
Adjusts the current state to be correct based on the Adjusts the current state to be correct based on the
received bad message notification whenever possible: received bad message notification whenever possible:
@ -481,7 +469,6 @@ class MTProtoSender:
error_code:int = BadMsgNotification; error_code:int = BadMsgNotification;
""" """
__log__.debug('Handling bad message') __log__.debug('Handling bad message')
bad_msg = reader.tgread_object()
if bad_msg.error_code in (16, 17): if bad_msg.error_code in (16, 17):
# Sent msg_id too low or too high (respectively). # Sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset. # Use the current msg_id to determine the right time offset.
@ -502,7 +489,7 @@ class MTProtoSender:
# Messages are to be re-sent once we've corrected the issue # Messages are to be re-sent once we've corrected the issue
await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id]) await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id])
async def _handle_detailed_info(self, message, reader): async def _handle_detailed_info(self, message, detailed_info):
""" """
Updates the current status with the received detailed information: Updates the current status with the received detailed information:
@ -511,9 +498,9 @@ class MTProtoSender:
""" """
# TODO https://goo.gl/VvpCC6 # TODO https://goo.gl/VvpCC6
__log__.debug('Handling detailed info') __log__.debug('Handling detailed info')
self._pending_ack.add(reader.tgread_object().answer_msg_id) self._pending_ack.add(detailed_info.answer_msg_id)
async def _handle_new_detailed_info(self, message, reader): async def _handle_new_detailed_info(self, message, new_detailed_info):
""" """
Updates the current status with the received detailed information: Updates the current status with the received detailed information:
@ -522,9 +509,9 @@ class MTProtoSender:
""" """
# TODO https://goo.gl/G7DPsR # TODO https://goo.gl/G7DPsR
__log__.debug('Handling new detailed info') __log__.debug('Handling new detailed info')
self._pending_ack.add(reader.tgread_object().answer_msg_id) self._pending_ack.add(new_detailed_info.answer_msg_id)
async def _handle_new_session_created(self, message, reader): async def _handle_new_session_created(self, message, new_session):
""" """
Updates the current status with the received session information: Updates the current status with the received session information:
@ -533,7 +520,7 @@ class MTProtoSender:
""" """
# TODO https://goo.gl/LMyN7A # TODO https://goo.gl/LMyN7A
__log__.debug('Handling new session created') __log__.debug('Handling new session created')
self.state.salt = reader.tgread_object().server_salt self.state.salt = new_session.server_salt
def _clean_containers(self, msg_ids): def _clean_containers(self, msg_ids):
""" """
@ -552,7 +539,7 @@ class MTProtoSender:
del self._pending_messages[message.msg_id] del self._pending_messages[message.msg_id]
break break
async def _handle_ack(self, message, reader): async def _handle_ack(self, message, ack):
""" """
Handles a server acknowledge about our messages. Normally Handles a server acknowledge about our messages. Normally
these can be ignored except in the case of ``auth.logOut``: these can be ignored except in the case of ``auth.logOut``:
@ -568,7 +555,6 @@ class MTProtoSender:
messages are acknowledged. messages are acknowledged.
""" """
__log__.debug('Handling acknowledge') __log__.debug('Handling acknowledge')
ack = reader.tgread_object()
if self._pending_containers: if self._pending_containers:
self._clean_containers(ack.msg_ids) self._clean_containers(ack.msg_ids)
@ -578,7 +564,7 @@ class MTProtoSender:
del self._pending_messages[msg_id] del self._pending_messages[msg_id]
msg.future.set_result(True) msg.future.set_result(True)
async def _handle_future_salts(self, message, reader): async def _handle_future_salts(self, message, salts):
""" """
Handles future salt results, which don't come inside a Handles future salt results, which don't come inside a
``rpc_result`` but are still sent through a request: ``rpc_result`` but are still sent through a request:
@ -589,7 +575,6 @@ class MTProtoSender:
# TODO save these salts and automatically adjust to the # TODO save these salts and automatically adjust to the
# correct one whenever the salt in use expires. # correct one whenever the salt in use expires.
__log__.debug('Handling future salts') __log__.debug('Handling future salts')
salts = reader.tgread_object()
msg = self._pending_messages.pop(message.msg_id, None) msg = self._pending_messages.pop(message.msg_id, None)
if msg: if msg:
msg.future.set_result(salts) msg.future.set_result(salts)

View File

@ -6,7 +6,7 @@ from hashlib import sha256
from ..crypto import AES from ..crypto import AES
from ..errors import SecurityError, BrokenAuthKeyError from ..errors import SecurityError, BrokenAuthKeyError
from ..extensions import BinaryReader from ..extensions import BinaryReader
from ..tl import TLMessage from ..tl.core import TLMessage
class MTProtoState: class MTProtoState:

View File

@ -1,4 +1 @@
from .tlobject import TLObject from .tlobject import TLObject
from .gzip_packed import GzipPacked
from .tl_message import TLMessage
from .message_container import MessageContainer

View File

@ -0,0 +1,26 @@
"""
This module holds core "special" types, which are more convenient ways
to do stuff in a `telethon.network.mtprotosender.MTProtoSender` instance.
Only special cases are gzip-packed data, the response message (not a
client message), the message container which references these messages
and would otherwise conflict with the rest, and finally the RpcResult:
rpc_result#f35c6d01 req_msg_id:long result:bytes = RpcResult;
Three things to note with this definition:
1. The constructor ID is actually ``42d36c2c``.
2. Those bytes are not read like the rest of bytes (length + payload).
They are actually the raw bytes of another object, which can't be
read directly because it depends on per-request information (since
some can return ``Vector<int>`` and ``Vector<long>``).
3. Those bytes may be gzipped data, which needs to be treated early.
"""
from .tlmessage import TLMessage
from .gzippacked import GzipPacked
from .messagecontainer import MessageContainer
from .rpcresult import RpcResult
core_objects = {x.CONSTRUCTOR_ID: x for x in (
GzipPacked, MessageContainer, RpcResult
)}

View File

@ -1,7 +1,7 @@
import gzip import gzip
import struct import struct
from . import TLObject from .. import TLObject
class GzipPacked(TLObject): class GzipPacked(TLObject):
@ -36,3 +36,7 @@ class GzipPacked(TLObject):
def read(reader): def read(reader):
assert reader.read_int(signed=False) == GzipPacked.CONSTRUCTOR_ID assert reader.read_int(signed=False) == GzipPacked.CONSTRUCTOR_ID
return gzip.decompress(reader.tgread_bytes()) return gzip.decompress(reader.tgread_bytes())
@classmethod
def from_reader(cls, reader):
return GzipPacked(gzip.decompress(reader.tgread_bytes()))

View File

@ -1,7 +1,7 @@
import struct import struct
from . import TLObject from ..tlobject import TLObject
from .tl_message import TLMessage from .tlmessage import TLMessage
class MessageContainer(TLObject): class MessageContainer(TLObject):
@ -42,3 +42,12 @@ class MessageContainer(TLObject):
def stringify(self): def stringify(self):
return TLObject.pretty_format(self, indent=0) return TLObject.pretty_format(self, indent=0)
@classmethod
def from_reader(cls, reader):
# This assumes that .read_* calls are done in the order they appear
return MessageContainer([TLMessage(
msg_id=reader.read_long(),
seq_no=reader.read_int(),
body=reader.read(reader.read_int())
) for _ in range(reader.read_int())])

View File

@ -0,0 +1,23 @@
from .gzippacked import GzipPacked
from ..types import RpcError
class RpcResult:
CONSTRUCTOR_ID = 0xf35c6d01
def __init__(self, req_msg_id, body, error):
self.req_msg_id = req_msg_id
self.body = body
self.error = error
@classmethod
def from_reader(cls, reader):
msg_id = reader.read_long()
inner_code = reader.read_int(signed=False)
if inner_code == RpcError.CONSTRUCTOR_ID:
return RpcResult(msg_id, None, RpcError.from_reader(reader))
if inner_code == GzipPacked.CONSTRUCTOR_ID:
return RpcResult(msg_id, GzipPacked.from_reader(reader).data, None)
reader.seek(-4)
return RpcResult(msg_id, reader.read(), None)

View File

@ -1,8 +1,9 @@
import asyncio import asyncio
import struct import struct
from . import TLObject, GzipPacked from .gzippacked import GzipPacked
from ..tl.functions import InvokeAfterMsgRequest from .. import TLObject
from ..functions import InvokeAfterMsgRequest
class TLMessage(TLObject): class TLMessage(TLObject):

View File

@ -49,7 +49,7 @@ new_session_created#9ec20908 first_msg_id:long unique_id:long server_salt:long =
//message msg_id:long seqno:int bytes:int body:bytes = Message; //message msg_id:long seqno:int bytes:int body:bytes = Message;
//msg_copy#e06046b2 orig_message:Message = MessageCopy; //msg_copy#e06046b2 orig_message:Message = MessageCopy;
gzip_packed#3072cfa1 packed_data:bytes = Object; //gzip_packed#3072cfa1 packed_data:bytes = Object;
msgs_ack#62d6b459 msg_ids:Vector<long> = MsgsAck; msgs_ack#62d6b459 msg_ids:Vector<long> = MsgsAck;