Telethon/telethon/network/mtprotosender.py
2018-06-06 21:42:48 +02:00

206 lines
7.8 KiB
Python

import asyncio
import logging
from .connection import ConnectionTcpFull
from .. import helpers
from ..errors import rpc_message_to_error
from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo
)
__log__ = logging.getLogger(__name__)
# TODO Create some kind of "ReconnectionPolicy" that allows specifying
# what should be done in case of some errors, with some sane defaults.
# For instance, should all messages be set with an error upon network
# loss? Should we try reconnecting forever? A certain amount of times?
# A timeout? What about recoverable errors, like connection reset?
class MTProtoSender:
def __init__(self, session):
self.session = session
self._connection = ConnectionTcpFull()
self._user_connected = False
# Send and receive calls must be atomic
self._send_lock = asyncio.Lock()
self._recv_lock = asyncio.Lock()
# We need to join the loops upon disconnection
self._send_loop_handle = None
self._recv_loop_handle = None
# Sending something shouldn't block
self._send_queue = asyncio.Queue()
# Telegram responds to messages out of order. Keep
# {id: Message} to set their Future result upon arrival.
self._pending_messages = {}
# We need to acknowledge every response from Telegram
self._pending_ack = set()
# Jump table from response ID to method that handles it
self._handlers = {
0xf35c6d01: self._handle_rpc_result,
MessageContainer.CONSTRUCTOR_ID: self._handle_container,
GzipPacked.CONSTRUCTOR_ID: self._handle_gzip_packed,
Pong.CONSTRUCTOR_ID: self._handle_pong,
BadServerSalt.CONSTRUCTOR_ID: self._handle_bad_server_salt,
BadMsgNotification.CONSTRUCTOR_ID: self._handle_bad_notification,
MsgDetailedInfo.CONSTRUCTOR_ID: self._handle_detailed_info,
MsgNewDetailedInfo.CONSTRUCTOR_ID: self._handle_new_detailed_info,
NewSessionCreated.CONSTRUCTOR_ID: self._handle_new_session_created,
MsgsAck.CONSTRUCTOR_ID: self._handle_ack,
FutureSalts.CONSTRUCTOR_ID: self._handle_future_salts
}
# Public API
async def connect(self, ip, port):
async with self._send_lock:
await self._connection.connect(ip, port)
self._user_connected = True
self._send_loop_handle = asyncio.ensure_future(self._send_loop())
self._recv_loop_handle = asyncio.ensure_future(self._recv_loop())
async def disconnect(self):
self._user_connected = False
try:
async with self._send_lock:
await self._connection.close()
except:
__log__.exception('Ignoring exception upon disconnection')
finally:
self._send_loop_handle.cancel()
self._recv_loop_handle.cancel()
async def send(self, request):
# TODO Should the asyncio.Future creation belong here?
request.result = asyncio.Future()
message = TLMessage(self.session, request)
self._pending_messages[message.msg_id] = message
await self._send_queue.put(message)
# Loops
async def _send_loop(self):
while self._user_connected:
# TODO If there's more than one item, send them all at once
body = helpers.pack_message(
self.session, await self._send_queue.get())
# TODO Handle exceptions
async with self._send_lock:
await self._connection.send(body)
async def _recv_loop(self):
while self._user_connected:
# TODO Handle exceptions
async with self._recv_lock:
body = await self._connection.recv()
# TODO Check salt, session_id and sequence_number
message, remote_msg_id, remote_seq = helpers.unpack_message(
self.session, body)
with BinaryReader(message) as reader:
await self._process_message(remote_msg_id, remote_seq, reader)
# Response Handlers
async def _process_message(self, msg_id, seq, reader):
self._pending_ack.add(msg_id)
code = reader.read_int(signed=False)
reader.seek(-4)
handler = self._handlers.get(code)
if handler:
await handler(msg_id, seq, reader)
else:
pass # TODO Process updates and their entities
async def _handle_rpc_result(self, msg_id, seq, reader):
# TODO Don't make this a special case
reader.read_int(signed=False) # code
message_id = reader.read_long()
inner_code = reader.read_int(signed=False)
reader.seek(-4)
message = self._pending_messages.pop(message_id)
if inner_code == 0x2144ca19: # RPC Error
reader.seek(4)
if self.session.report_errors and message:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string(),
report_method=type(message.request).CONSTRUCTOR_ID
)
else:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string()
)
# TODO Acknowledge that we received the error request_id
# TODO Set message.request exception
elif message:
# TODO Make on_response result.set_result() instead replacing it
if inner_code == GzipPacked.CONSTRUCTOR_ID:
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
message.on_response(compressed_reader)
else:
message.on_response(reader)
# TODO Process possible entities
# TODO Try reading an object
async def _handle_container(self, msg_id, seq, reader):
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
next_position = reader.tell_position() + inner_len
await self._process_message(inner_msg_id, seq, reader)
reader.set_position(next_position) # Ensure reading correctly
async def _handle_gzip_packed(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_pong(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_bad_server_salt(self, msg_id, seq, reader):
bad_salt = reader.tgread_object()
self.session.salt = bad_salt.new_server_salt
self.session.save()
# "the bad_server_salt response is received with the
# correct salt, and the message is to be re-sent with it"
await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id])
async def _handle_bad_notification(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_detailed_info(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_new_detailed_info(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_new_session_created(self, msg_id, seq, reader):
# TODO https://goo.gl/LMyN7A
new_session = reader.tgread_object()
self.session.salt = new_session.server_salt
async def _handle_ack(self, msg_id, seq, reader):
# Ignore every ack request *unless* when logging out, when it's
# when it seems to only make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these.
for msg_id in reader.tgread_object().msg_ids:
# TODO pop msg_id if of type LogOutRequest, and confirm it
pass
return True
async def _handle_future_salts(self, msg_id, seq, reader):
raise NotImplementedError