diff --git a/telethon/errors/__init__.py b/telethon/errors/__init__.py index 8b4e9f88..ca050de9 100644 --- a/telethon/errors/__init__.py +++ b/telethon/errors/__init__.py @@ -40,49 +40,49 @@ def report_error(code, message, report_method): "We really don't want to crash when just reporting an error" -def rpc_message_to_error(code, message, report_method=None): +def rpc_message_to_error(rpc_error, report_method=None): """ Converts a Telegram's RPC Error to a Python error. - :param code: the integer code of the error (like 400). - :param message: the message representing the error. + :param rpc_error: the RpcError instance. :param report_method: if present, the ID of the method that caused it. :return: the RPCError as a Python exception that represents this error. """ if report_method is not None: Thread( target=report_error, - args=(code, message, report_method) + args=(rpc_error.error_code, rpc_error.error_message, report_method) ).start() # Try to get the error by direct look-up, otherwise regex # TODO Maybe regexes could live in a separate dictionary? - cls = rpc_errors_all.get(message, None) + cls = rpc_errors_all.get(rpc_error.error_message, None) if cls: return cls() for msg_regex, cls in rpc_errors_all.items(): - m = re.match(msg_regex, message) + m = re.match(msg_regex, rpc_error.error_message) if m: capture = int(m.group(1)) if m.groups() else None return cls(capture=capture) - if code == 400: - return BadRequestError(message) + if rpc_error.error_code == 400: + return BadRequestError(rpc_error.error_message) - if code == 401: - return UnauthorizedError(message) + if rpc_error.error_code == 401: + return UnauthorizedError(rpc_error.error_message) - if code == 403: - return ForbiddenError(message) + if rpc_error.error_code == 403: + return ForbiddenError(rpc_error.error_message) - if code == 404: - return NotFoundError(message) + if rpc_error.error_code == 404: + return NotFoundError(rpc_error.error_message) - if code == 406: - return AuthKeyError(message) + if rpc_error.error_code == 406: + return AuthKeyError(rpc_error.error_message) - if code == 500: - return ServerError(message) + if rpc_error.error_code == 500: + return ServerError(rpc_error.error_message) - return RPCError('{} (code {})'.format(message, code)) + return RPCError('{} (code {})'.format( + rpc_error.error_message, rpc_error.error_code)) diff --git a/telethon/extensions/binary_reader.py b/telethon/extensions/binary_reader.py index ecf7dd1b..e7496d77 100644 --- a/telethon/extensions/binary_reader.py +++ b/telethon/extensions/binary_reader.py @@ -8,6 +8,7 @@ from struct import unpack from ..errors import TypeNotFoundError from ..tl.all_tlobjects import tlobjects +from ..tl.core import core_objects class BinaryReader: @@ -136,9 +137,11 @@ class BinaryReader: elif value == 0x1cb5c415: # Vector return [self.tgread_object() for _ in range(self.read_int())] - # If there was still no luck, give up - self.seek(-4) # Go back - raise TypeNotFoundError(constructor_id) + clazz = core_objects.get(constructor_id, None) + if clazz is None: + # If there was still no luck, give up + self.seek(-4) # Go back + raise TypeNotFoundError(constructor_id) return clazz.from_reader(self) diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index 00a9c038..28369b00 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -9,7 +9,7 @@ from ..errors import ( rpc_message_to_error ) from ..extensions import BinaryReader -from ..tl import MessageContainer, GzipPacked +from ..tl.core import RpcResult, MessageContainer, GzipPacked from ..tl.functions.auth import LogOutRequest from ..tl.types import ( MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts, @@ -80,7 +80,7 @@ class MTProtoSender: # Jump table from response ID to method that handles it self._handlers = { - 0xf35c6d01: self._handle_rpc_result, + RpcResult.CONSTRUCTOR_ID: self._handle_rpc_result, MessageContainer.CONSTRUCTOR_ID: self._handle_container, GzipPacked.CONSTRUCTOR_ID: self._handle_gzip_packed, Pong.CONSTRUCTOR_ID: self._handle_pong, @@ -354,26 +354,26 @@ class MTProtoSender: else: try: with BinaryReader(message.body) as reader: - await self._process_message(message, reader) + obj = reader.tgread_object() except TypeNotFoundError as e: __log__.warning('Could not decode received message: {}, ' 'raw bytes: {!r}'.format(e, message)) + else: + await self._process_message(message, obj) # Response Handlers - async def _process_message(self, message, reader): + async def _process_message(self, message, obj): """ Adds the given message to the list of messages that must be acknowledged and dispatches control to different ``_handle_*`` method based on its type. """ self._pending_ack.add(message.msg_id) - code = reader.read_int(signed=False) - reader.seek(-4) - handler = self._handlers.get(code, self._handle_update) - await handler(message, reader) + handler = self._handlers.get(obj.CONSTRUCTOR_ID, self._handle_update) + await handler(message, obj) - async def _handle_rpc_result(self, message, reader): + async def _handle_rpc_result(self, message, rpc_result): """ Handles the result for Remote Procedure Calls: @@ -381,20 +381,13 @@ class MTProtoSender: This is where the future results for sent requests are set. """ - # TODO Don't make this a special cased object - reader.read_int(signed=False) # code - message_id = reader.read_long() - inner_code = reader.read_int(signed=False) - reader.seek(-4) + message = self._pending_messages.pop(rpc_result.req_msg_id, None) + __log__.debug('Handling RPC result for message {}' + .format(rpc_result.req_msg_id)) - __log__.debug('Handling RPC result for message {}'.format(message_id)) - message = self._pending_messages.pop(message_id, None) - if inner_code == 0x2144ca19: # RPC Error + if rpc_result.error: # TODO Report errors if possible/enabled - reader.seek(4) - error = rpc_message_to_error(reader.read_int(), - reader.tgread_string()) - + error = rpc_message_to_error(rpc_result.error) await self._send_queue.put(self.state.create_message( MsgsAck([message.msg_id]) )) @@ -403,10 +396,7 @@ class MTProtoSender: message.future.set_exception(error) return elif message: - if inner_code == GzipPacked.CONSTRUCTOR_ID: - with BinaryReader(GzipPacked.read(reader)) as compressed_reader: - result = message.request.read_result(compressed_reader) - else: + with BinaryReader(rpc_result.body) as reader: result = message.request.read_result(reader) # TODO Process entities @@ -416,37 +406,37 @@ class MTProtoSender: else: # TODO We should not get responses to things we never sent __log__.info('Received response without parent request: {}' - .format(reader.tgread_object())) + .format(rpc_result.body)) - async def _handle_container(self, message, reader): + async def _handle_container(self, message, container): """ Processes the inner messages of a container with many of them: msg_container#73f1f8dc messages:vector<%Message> = MessageContainer; """ __log__.debug('Handling container') - for inner_message in MessageContainer.iter_read(reader): - with BinaryReader(inner_message.body) as inner_reader: - await self._process_message(inner_message, inner_reader) + for inner_message in container.messages: + with BinaryReader(inner_message.body) as reader: + inner_obj = reader.tgread_object() + await self._process_message(inner_message, inner_obj) - async def _handle_gzip_packed(self, message, reader): + async def _handle_gzip_packed(self, message, gzip_packed): """ Unpacks the data from a gzipped object and processes it: gzip_packed#3072cfa1 packed_data:bytes = Object; """ __log__.debug('Handling gzipped data') - with BinaryReader(GzipPacked.read(reader)) as compressed_reader: - await self._process_message(message, compressed_reader) + with BinaryReader(gzip_packed.data) as reader: + await self._process_message(message, reader.tgread_object()) - async def _handle_update(self, message, reader): - obj = reader.tgread_object() - __log__.debug('Handling update {}'.format(obj.__class__.__name__)) + async def _handle_update(self, message, update): + __log__.debug('Handling update {}'.format(update.__class__.__name__)) # TODO Further handling of the update # TODO Process entities - async def _handle_pong(self, message, reader): + async def _handle_pong(self, message, pong): """ Handles pong results, which don't come inside a ``rpc_result`` but are still sent through a request: @@ -454,12 +444,11 @@ class MTProtoSender: pong#347773c5 msg_id:long ping_id:long = Pong; """ __log__.debug('Handling pong') - pong = reader.tgread_object() message = self._pending_messages.pop(pong.msg_id, None) if message: message.future.set_result(pong) - async def _handle_bad_server_salt(self, message, reader): + async def _handle_bad_server_salt(self, message, bad_salt): """ Corrects the currently used server salt to use the right value before enqueuing the rejected message to be re-sent: @@ -468,11 +457,10 @@ class MTProtoSender: error_code:int new_server_salt:long = BadMsgNotification; """ __log__.debug('Handling bad salt') - bad_salt = reader.tgread_object() self.state.salt = bad_salt.new_server_salt await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id]) - async def _handle_bad_notification(self, message, reader): + async def _handle_bad_notification(self, message, bad_msg): """ Adjusts the current state to be correct based on the received bad message notification whenever possible: @@ -481,7 +469,6 @@ class MTProtoSender: error_code:int = BadMsgNotification; """ __log__.debug('Handling bad message') - bad_msg = reader.tgread_object() if bad_msg.error_code in (16, 17): # Sent msg_id too low or too high (respectively). # Use the current msg_id to determine the right time offset. @@ -502,7 +489,7 @@ class MTProtoSender: # Messages are to be re-sent once we've corrected the issue await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id]) - async def _handle_detailed_info(self, message, reader): + async def _handle_detailed_info(self, message, detailed_info): """ Updates the current status with the received detailed information: @@ -511,9 +498,9 @@ class MTProtoSender: """ # TODO https://goo.gl/VvpCC6 __log__.debug('Handling detailed info') - self._pending_ack.add(reader.tgread_object().answer_msg_id) + self._pending_ack.add(detailed_info.answer_msg_id) - async def _handle_new_detailed_info(self, message, reader): + async def _handle_new_detailed_info(self, message, new_detailed_info): """ Updates the current status with the received detailed information: @@ -522,9 +509,9 @@ class MTProtoSender: """ # TODO https://goo.gl/G7DPsR __log__.debug('Handling new detailed info') - self._pending_ack.add(reader.tgread_object().answer_msg_id) + self._pending_ack.add(new_detailed_info.answer_msg_id) - async def _handle_new_session_created(self, message, reader): + async def _handle_new_session_created(self, message, new_session): """ Updates the current status with the received session information: @@ -533,7 +520,7 @@ class MTProtoSender: """ # TODO https://goo.gl/LMyN7A __log__.debug('Handling new session created') - self.state.salt = reader.tgread_object().server_salt + self.state.salt = new_session.server_salt def _clean_containers(self, msg_ids): """ @@ -552,7 +539,7 @@ class MTProtoSender: del self._pending_messages[message.msg_id] break - async def _handle_ack(self, message, reader): + async def _handle_ack(self, message, ack): """ Handles a server acknowledge about our messages. Normally these can be ignored except in the case of ``auth.logOut``: @@ -568,7 +555,6 @@ class MTProtoSender: messages are acknowledged. """ __log__.debug('Handling acknowledge') - ack = reader.tgread_object() if self._pending_containers: self._clean_containers(ack.msg_ids) @@ -578,7 +564,7 @@ class MTProtoSender: del self._pending_messages[msg_id] msg.future.set_result(True) - async def _handle_future_salts(self, message, reader): + async def _handle_future_salts(self, message, salts): """ Handles future salt results, which don't come inside a ``rpc_result`` but are still sent through a request: @@ -589,7 +575,6 @@ class MTProtoSender: # TODO save these salts and automatically adjust to the # correct one whenever the salt in use expires. __log__.debug('Handling future salts') - salts = reader.tgread_object() msg = self._pending_messages.pop(message.msg_id, None) if msg: msg.future.set_result(salts) diff --git a/telethon/network/mtprotostate.py b/telethon/network/mtprotostate.py index 7c37f14b..36a9dde1 100644 --- a/telethon/network/mtprotostate.py +++ b/telethon/network/mtprotostate.py @@ -6,7 +6,7 @@ from hashlib import sha256 from ..crypto import AES from ..errors import SecurityError, BrokenAuthKeyError from ..extensions import BinaryReader -from ..tl import TLMessage +from ..tl.core import TLMessage class MTProtoState: diff --git a/telethon/tl/__init__.py b/telethon/tl/__init__.py index 96c934bb..b2ffbca8 100644 --- a/telethon/tl/__init__.py +++ b/telethon/tl/__init__.py @@ -1,4 +1 @@ from .tlobject import TLObject -from .gzip_packed import GzipPacked -from .tl_message import TLMessage -from .message_container import MessageContainer diff --git a/telethon/tl/core/__init__.py b/telethon/tl/core/__init__.py new file mode 100644 index 00000000..3113196a --- /dev/null +++ b/telethon/tl/core/__init__.py @@ -0,0 +1,26 @@ +""" +This module holds core "special" types, which are more convenient ways +to do stuff in a `telethon.network.mtprotosender.MTProtoSender` instance. + +Only special cases are gzip-packed data, the response message (not a +client message), the message container which references these messages +and would otherwise conflict with the rest, and finally the RpcResult: + + rpc_result#f35c6d01 req_msg_id:long result:bytes = RpcResult; + +Three things to note with this definition: +1. The constructor ID is actually ``42d36c2c``. +2. Those bytes are not read like the rest of bytes (length + payload). + They are actually the raw bytes of another object, which can't be + read directly because it depends on per-request information (since + some can return ``Vector`` and ``Vector``). +3. Those bytes may be gzipped data, which needs to be treated early. +""" +from .tlmessage import TLMessage +from .gzippacked import GzipPacked +from .messagecontainer import MessageContainer +from .rpcresult import RpcResult + +core_objects = {x.CONSTRUCTOR_ID: x for x in ( + GzipPacked, MessageContainer, RpcResult +)} diff --git a/telethon/tl/gzip_packed.py b/telethon/tl/core/gzippacked.py similarity index 89% rename from telethon/tl/gzip_packed.py rename to telethon/tl/core/gzippacked.py index 053acd86..6ec61b49 100644 --- a/telethon/tl/gzip_packed.py +++ b/telethon/tl/core/gzippacked.py @@ -1,7 +1,7 @@ import gzip import struct -from . import TLObject +from .. import TLObject class GzipPacked(TLObject): @@ -36,3 +36,7 @@ class GzipPacked(TLObject): def read(reader): assert reader.read_int(signed=False) == GzipPacked.CONSTRUCTOR_ID return gzip.decompress(reader.tgread_bytes()) + + @classmethod + def from_reader(cls, reader): + return GzipPacked(gzip.decompress(reader.tgread_bytes())) diff --git a/telethon/tl/message_container.py b/telethon/tl/core/messagecontainer.py similarity index 75% rename from telethon/tl/message_container.py rename to telethon/tl/core/messagecontainer.py index acd51bb4..ef1eab1e 100644 --- a/telethon/tl/message_container.py +++ b/telethon/tl/core/messagecontainer.py @@ -1,7 +1,7 @@ import struct -from . import TLObject -from .tl_message import TLMessage +from ..tlobject import TLObject +from .tlmessage import TLMessage class MessageContainer(TLObject): @@ -42,3 +42,12 @@ class MessageContainer(TLObject): def stringify(self): return TLObject.pretty_format(self, indent=0) + + @classmethod + def from_reader(cls, reader): + # This assumes that .read_* calls are done in the order they appear + return MessageContainer([TLMessage( + msg_id=reader.read_long(), + seq_no=reader.read_int(), + body=reader.read(reader.read_int()) + ) for _ in range(reader.read_int())]) diff --git a/telethon/tl/core/rpcresult.py b/telethon/tl/core/rpcresult.py new file mode 100644 index 00000000..08b7a555 --- /dev/null +++ b/telethon/tl/core/rpcresult.py @@ -0,0 +1,23 @@ +from .gzippacked import GzipPacked +from ..types import RpcError + + +class RpcResult: + CONSTRUCTOR_ID = 0xf35c6d01 + + def __init__(self, req_msg_id, body, error): + self.req_msg_id = req_msg_id + self.body = body + self.error = error + + @classmethod + def from_reader(cls, reader): + msg_id = reader.read_long() + inner_code = reader.read_int(signed=False) + if inner_code == RpcError.CONSTRUCTOR_ID: + return RpcResult(msg_id, None, RpcError.from_reader(reader)) + if inner_code == GzipPacked.CONSTRUCTOR_ID: + return RpcResult(msg_id, GzipPacked.from_reader(reader).data, None) + + reader.seek(-4) + return RpcResult(msg_id, reader.read(), None) diff --git a/telethon/tl/tl_message.py b/telethon/tl/core/tlmessage.py similarity index 95% rename from telethon/tl/tl_message.py rename to telethon/tl/core/tlmessage.py index e37902dc..6de84669 100644 --- a/telethon/tl/tl_message.py +++ b/telethon/tl/core/tlmessage.py @@ -1,8 +1,9 @@ import asyncio import struct -from . import TLObject, GzipPacked -from ..tl.functions import InvokeAfterMsgRequest +from .gzippacked import GzipPacked +from .. import TLObject +from ..functions import InvokeAfterMsgRequest class TLMessage(TLObject): diff --git a/telethon_generator/data/mtproto_api.tl b/telethon_generator/data/mtproto_api.tl index aa5e4c97..79fbb40d 100644 --- a/telethon_generator/data/mtproto_api.tl +++ b/telethon_generator/data/mtproto_api.tl @@ -49,7 +49,7 @@ new_session_created#9ec20908 first_msg_id:long unique_id:long server_salt:long = //message msg_id:long seqno:int bytes:int body:bytes = Message; //msg_copy#e06046b2 orig_message:Message = MessageCopy; -gzip_packed#3072cfa1 packed_data:bytes = Object; +//gzip_packed#3072cfa1 packed_data:bytes = Object; msgs_ack#62d6b459 msg_ids:Vector = MsgsAck;