Telethon/telethon/network/mtprotosender.py
2018-06-07 10:30:20 +02:00

235 lines
8.9 KiB
Python

import asyncio
import logging
from .connection import ConnectionTcpFull
from .. import helpers
from ..errors import rpc_message_to_error
from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo
)
__log__ = logging.getLogger(__name__)
# TODO Create some kind of "ReconnectionPolicy" that allows specifying
# what should be done in case of some errors, with some sane defaults.
# For instance, should all messages be set with an error upon network
# loss? Should we try reconnecting forever? A certain amount of times?
# A timeout? What about recoverable errors, like connection reset?
class MTProtoSender:
def __init__(self, session):
self.session = session
self._connection = ConnectionTcpFull()
self._user_connected = False
# Send and receive calls must be atomic
self._send_lock = asyncio.Lock()
self._recv_lock = asyncio.Lock()
# We need to join the loops upon disconnection
self._send_loop_handle = None
self._recv_loop_handle = None
# Sending something shouldn't block
self._send_queue = asyncio.Queue()
# Telegram responds to messages out of order. Keep
# {id: Message} to set their Future result upon arrival.
self._pending_messages = {}
# We need to acknowledge every response from Telegram
self._pending_ack = set()
# Jump table from response ID to method that handles it
self._handlers = {
0xf35c6d01: self._handle_rpc_result,
MessageContainer.CONSTRUCTOR_ID: self._handle_container,
GzipPacked.CONSTRUCTOR_ID: self._handle_gzip_packed,
Pong.CONSTRUCTOR_ID: self._handle_pong,
BadServerSalt.CONSTRUCTOR_ID: self._handle_bad_server_salt,
BadMsgNotification.CONSTRUCTOR_ID: self._handle_bad_notification,
MsgDetailedInfo.CONSTRUCTOR_ID: self._handle_detailed_info,
MsgNewDetailedInfo.CONSTRUCTOR_ID: self._handle_new_detailed_info,
NewSessionCreated.CONSTRUCTOR_ID: self._handle_new_session_created,
MsgsAck.CONSTRUCTOR_ID: self._handle_ack,
FutureSalts.CONSTRUCTOR_ID: self._handle_future_salts
}
# Public API
async def connect(self, ip, port):
async with self._send_lock:
await self._connection.connect(ip, port)
self._user_connected = True
self._send_loop_handle = asyncio.ensure_future(self._send_loop())
self._recv_loop_handle = asyncio.ensure_future(self._recv_loop())
async def disconnect(self):
self._user_connected = False
try:
async with self._send_lock:
await self._connection.close()
except:
__log__.exception('Ignoring exception upon disconnection')
finally:
self._send_loop_handle.cancel()
self._recv_loop_handle.cancel()
async def send(self, request):
"""
This method enqueues the given request to be sent.
The request will be wrapped inside a `TLMessage` until its
response arrives, and the `Future` response of the `TLMessage`
is immediately returned so that one can further ``await`` it:
.. code-block:: python
async def method():
# Sending (enqueued for the send loop)
future = await sender.send(request)
# Receiving (waits for the receive loop to read the result)
result = await future
Designed like this because Telegram may send the response at
any point, and it can send other items while one waits for it.
Once the response for this future arrives, it is set with the
received result, quite similar to how a ``receive()`` call
would otherwise work.
Since the receiving part is "built in" the future, it's
impossible to await receive a result that was never sent.
"""
message = TLMessage(self.session, request)
self._pending_messages[message.msg_id] = message
await self._send_queue.put(message)
return message.future
# Loops
async def _send_loop(self):
while self._user_connected:
# TODO If there's more than one item, send them all at once
body = helpers.pack_message(
self.session, await self._send_queue.get())
# TODO Handle exceptions
async with self._send_lock:
await self._connection.send(body)
async def _recv_loop(self):
while self._user_connected:
# TODO Handle exceptions
async with self._recv_lock:
body = await self._connection.recv()
# TODO Check salt, session_id and sequence_number
message, remote_msg_id, remote_seq = helpers.unpack_message(
self.session, body)
with BinaryReader(message) as reader:
await self._process_message(remote_msg_id, remote_seq, reader)
# Response Handlers
async def _process_message(self, msg_id, seq, reader):
self._pending_ack.add(msg_id)
code = reader.read_int(signed=False)
reader.seek(-4)
handler = self._handlers.get(code)
if handler:
await handler(msg_id, seq, reader)
else:
pass # TODO Process updates and their entities
async def _handle_rpc_result(self, msg_id, seq, reader):
# TODO Don't make this a special case
reader.read_int(signed=False) # code
message_id = reader.read_long()
inner_code = reader.read_int(signed=False)
reader.seek(-4)
message = self._pending_messages.pop(message_id, None)
if inner_code == 0x2144ca19: # RPC Error
reader.seek(4)
if self.session.report_errors and message:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string(),
report_method=type(message.request).CONSTRUCTOR_ID
)
else:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string()
)
await self._send_queue.put(
TLMessage(self.session, MsgsAck([msg_id])))
if not message.future.cancelled():
message.future.set_exception(error)
return
elif message:
if inner_code == GzipPacked.CONSTRUCTOR_ID:
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
result = message.request.read_result(compressed_reader)
else:
result = message.request.read_result(reader)
# TODO Process possible entities
if not message.future.cancelled():
message.future.set_result(result)
return
# TODO Try reading an object
async def _handle_container(self, msg_id, seq, reader):
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
next_position = reader.tell_position() + inner_len
await self._process_message(inner_msg_id, seq, reader)
reader.set_position(next_position) # Ensure reading correctly
async def _handle_gzip_packed(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_pong(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_bad_server_salt(self, msg_id, seq, reader):
bad_salt = reader.tgread_object()
self.session.salt = bad_salt.new_server_salt
self.session.save()
# "the bad_server_salt response is received with the
# correct salt, and the message is to be re-sent with it"
await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id])
async def _handle_bad_notification(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_detailed_info(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_new_detailed_info(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_new_session_created(self, msg_id, seq, reader):
# TODO https://goo.gl/LMyN7A
new_session = reader.tgread_object()
self.session.salt = new_session.server_salt
async def _handle_ack(self, msg_id, seq, reader):
# Ignore every ack request *unless* when logging out, when it's
# when it seems to only make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these.
for msg_id in reader.tgread_object().msg_ids:
# TODO pop msg_id if of type LogOutRequest, and confirm it
pass
return True
async def _handle_future_salts(self, msg_id, seq, reader):
raise NotImplementedError