Create a self-contained MTProtoState

This frees us from using entire Session objects in something
that's supposed to just send and receive items from the net.
This commit is contained in:
Lonami Exo
2018-06-09 11:34:01 +02:00
parent cc5753137c
commit adfe861e9f
5 changed files with 226 additions and 146 deletions

View File

@@ -3,13 +3,13 @@ import logging
from . import MTProtoPlainSender, authenticator
from .connection import ConnectionTcpFull
from .. import helpers, utils
from .. import utils
from ..errors import (
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
rpc_message_to_error
)
from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl import MessageContainer, GzipPacked
from ..tl.functions.auth import LogOutRequest
from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
@@ -39,8 +39,8 @@ class MTProtoSender:
A new authorization key will be generated on connection if no other
key exists yet.
"""
def __init__(self, session, retries=5):
self.session = session
def __init__(self, state, retries=5):
self.state = state
self._connection = ConnectionTcpFull()
self._ip = None
self._port = None
@@ -171,21 +171,17 @@ class MTProtoSender:
# a `Future` that you need to further ``await`` instead of the
# currently double ``await (await send())``?
if utils.is_list_like(request):
if not ordered:
# False-y values must be None to do after_id = ordered and ...
ordered = None
result = []
after_id = None
after = None
for r in request:
message = TLMessage(self.session, r, after_id=after_id)
message = self.state.create_message(r, after=after)
self._pending_messages[message.msg_id] = message
after_id = ordered and message.msg_id
await self._send_queue.put(message)
result.append(message.future)
after = ordered and message
return result
else:
message = TLMessage(self.session, request)
message = self.state.create_message(request)
self._pending_messages[message.msg_id] = message
await self._send_queue.put(message)
return message.future
@@ -215,13 +211,13 @@ class MTProtoSender:
raise _last_error
__log__.debug('Connection success!')
if self.session.auth_key is None:
if self.state.auth_key is None:
_last_error = SecurityError()
plain = MTProtoPlainSender(self._connection)
for retry in range(1, self._retries + 1):
try:
__log__.debug('New auth_key attempt {}...'.format(retry))
self.session.auth_key, self.session.time_offset =\
self.state.auth_key, self.state.time_offset =\
await authenticator.do_authentication(plain)
except (SecurityError, AssertionError) as e:
_last_error = e
@@ -268,13 +264,14 @@ class MTProtoSender:
"""
while self._user_connected and not self._reconnecting:
if self._pending_ack:
await self._send_queue.put(TLMessage(
self.session, MsgsAck(list(self._pending_ack))))
await self._send_queue.put(self.state.create_message(
MsgsAck(list(self._pending_ack))
))
self._pending_ack.clear()
messages = await self._send_queue.get()
if isinstance(messages, list):
message = TLMessage(self.session, MessageContainer(messages))
message = self.state.create_message(MessageContainer(messages))
self._pending_messages[message.msg_id] = message
self._pending_containers.append(message)
else:
@@ -283,7 +280,7 @@ class MTProtoSender:
__log__.debug('Packing {} outgoing message(s)...'
.format(len(messages)))
body = helpers.pack_message(self.session, message)
body = self.state.pack_message(message)
while not any(m.future.cancelled() for m in messages):
try:
@@ -333,8 +330,7 @@ class MTProtoSender:
# TODO Check salt, session_id and sequence_number
__log__.debug('Decoding packet of {} bytes...'.format(len(body)))
try:
message, remote_msg_id, remote_seq =\
helpers.unpack_message(self.session, body)
message = self.state.unpack_message(body)
except (BrokenAuthKeyError, BufferError) as e:
# The authorization key may be broken if a message was
# sent malformed, or if the authkey truly is corrupted.
@@ -346,7 +342,7 @@ class MTProtoSender:
# TODO Is it possible to detect malformed messages vs
# an actually broken authkey?
__log__.warning('Broken authorization key?: {}'.format(e))
self.session.auth_key = None
self.state.auth_key = None
asyncio.ensure_future(self._reconnect())
break
except SecurityError as e:
@@ -357,28 +353,27 @@ class MTProtoSender:
continue
else:
try:
with BinaryReader(message) as reader:
await self._process_message(
remote_msg_id, remote_seq, reader)
with BinaryReader(message.body) as reader:
await self._process_message(message, reader)
except TypeNotFoundError as e:
__log__.warning('Could not decode received message: {}, '
'raw bytes: {!r}'.format(e, message))
# Response Handlers
async def _process_message(self, msg_id, seq, reader):
async def _process_message(self, message, reader):
"""
Adds the given message to the list of messages that must be
acknowledged and dispatches control to different ``_handle_*``
method based on its type.
"""
self._pending_ack.add(msg_id)
self._pending_ack.add(message.msg_id)
code = reader.read_int(signed=False)
reader.seek(-4)
handler = self._handlers.get(code, self._handle_update)
await handler(msg_id, seq, reader)
await handler(message, reader)
async def _handle_rpc_result(self, msg_id, seq, reader):
async def _handle_rpc_result(self, message, reader):
"""
Handles the result for Remote Procedure Calls:
@@ -395,19 +390,14 @@ class MTProtoSender:
__log__.debug('Handling RPC result for message {}'.format(message_id))
message = self._pending_messages.pop(message_id, None)
if inner_code == 0x2144ca19: # RPC Error
# TODO Report errors if possible/enabled
reader.seek(4)
if self.session.report_errors and message:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string(),
report_method=type(message.request).CONSTRUCTOR_ID
)
else:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string()
)
error = rpc_message_to_error(reader.read_int(),
reader.tgread_string())
await self._send_queue.put(
TLMessage(self.session, MsgsAck([msg_id])))
await self._send_queue.put(self.state.create_message(
MsgsAck([message.msg_id])
))
if not message.future.cancelled():
message.future.set_exception(error)
@@ -419,7 +409,7 @@ class MTProtoSender:
else:
result = message.request.read_result(reader)
self.session.process_entities(result)
# TODO Process entities
if not message.future.cancelled():
message.future.set_result(result)
return
@@ -428,19 +418,18 @@ class MTProtoSender:
__log__.info('Received response without parent request: {}'
.format(reader.tgread_object()))
async def _handle_container(self, msg_id, seq, reader):
async def _handle_container(self, message, reader):
"""
Processes the inner messages of a container with many of them:
msg_container#73f1f8dc messages:vector<%Message> = MessageContainer;
"""
__log__.debug('Handling container')
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
next_position = reader.tell_position() + inner_len
await self._process_message(inner_msg_id, seq, reader)
reader.set_position(next_position) # Ensure reading correctly
for inner_message in MessageContainer.iter_read(reader):
with BinaryReader(inner_message.body) as inner_reader:
await self._process_message(inner_message, inner_reader)
async def _handle_gzip_packed(self, msg_id, seq, reader):
async def _handle_gzip_packed(self, message, reader):
"""
Unpacks the data from a gzipped object and processes it:
@@ -448,16 +437,16 @@ class MTProtoSender:
"""
__log__.debug('Handling gzipped data')
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
await self._process_message(msg_id, seq, compressed_reader)
await self._process_message(message, compressed_reader)
async def _handle_update(self, msg_id, seq, reader):
async def _handle_update(self, message, reader):
obj = reader.tgread_object()
__log__.debug('Handling update {}'.format(obj.__class__.__name__))
# TODO Further handling of the update
self.session.process_entities(obj)
# TODO Process entities
async def _handle_pong(self, msg_id, seq, reader):
async def _handle_pong(self, message, reader):
"""
Handles pong results, which don't come inside a ``rpc_result``
but are still sent through a request:
@@ -470,7 +459,7 @@ class MTProtoSender:
if message:
message.future.set_result(pong)
async def _handle_bad_server_salt(self, msg_id, seq, reader):
async def _handle_bad_server_salt(self, message, reader):
"""
Corrects the currently used server salt to use the right value
before enqueuing the rejected message to be re-sent:
@@ -480,11 +469,10 @@ class MTProtoSender:
"""
__log__.debug('Handling bad salt')
bad_salt = reader.tgread_object()
self.session.salt = bad_salt.new_server_salt
self.session.save()
self.state.salt = bad_salt.new_server_salt
await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id])
async def _handle_bad_notification(self, msg_id, seq, reader):
async def _handle_bad_notification(self, message, reader):
"""
Adjusts the current state to be correct based on the
received bad message notification whenever possible:
@@ -497,14 +485,14 @@ class MTProtoSender:
if bad_msg.error_code in (16, 17):
# Sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset.
self.session.update_time_offset(correct_msg_id=msg_id)
self.state.update_time_offset(correct_msg_id=message.msg_id)
elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID
self.session.sequence += 64
self.state._sequence += 64
elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case
self.session.sequence -= 16
self.state._sequence -= 16
else:
msg = self._pending_messages.pop(bad_msg.bad_msg_id, None)
if msg:
@@ -514,7 +502,7 @@ class MTProtoSender:
# Messages are to be re-sent once we've corrected the issue
await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id])
async def _handle_detailed_info(self, msg_id, seq, reader):
async def _handle_detailed_info(self, message, reader):
"""
Updates the current status with the received detailed information:
@@ -525,7 +513,7 @@ class MTProtoSender:
__log__.debug('Handling detailed info')
self._pending_ack.add(reader.tgread_object().answer_msg_id)
async def _handle_new_detailed_info(self, msg_id, seq, reader):
async def _handle_new_detailed_info(self, message, reader):
"""
Updates the current status with the received detailed information:
@@ -536,7 +524,7 @@ class MTProtoSender:
__log__.debug('Handling new detailed info')
self._pending_ack.add(reader.tgread_object().answer_msg_id)
async def _handle_new_session_created(self, msg_id, seq, reader):
async def _handle_new_session_created(self, message, reader):
"""
Updates the current status with the received session information:
@@ -545,7 +533,7 @@ class MTProtoSender:
"""
# TODO https://goo.gl/LMyN7A
__log__.debug('Handling new session created')
self.session.salt = reader.tgread_object().server_salt
self.state.salt = reader.tgread_object().server_salt
def _clean_containers(self, msg_ids):
"""
@@ -564,7 +552,7 @@ class MTProtoSender:
del self._pending_messages[message.msg_id]
break
async def _handle_ack(self, msg_id, seq, reader):
async def _handle_ack(self, message, reader):
"""
Handles a server acknowledge about our messages. Normally
these can be ignored except in the case of ``auth.logOut``:
@@ -590,7 +578,7 @@ class MTProtoSender:
del self._pending_messages[msg_id]
msg.future.set_result(True)
async def _handle_future_salts(self, msg_id, seq, reader):
async def _handle_future_salts(self, message, reader):
"""
Handles future salt results, which don't come inside a
``rpc_result`` but are still sent through a request:
@@ -602,7 +590,7 @@ class MTProtoSender:
# correct one whenever the salt in use expires.
__log__.debug('Handling future salts')
salts = reader.tgread_object()
msg = self._pending_messages.pop(msg_id, None)
msg = self._pending_messages.pop(message.msg_id, None)
if msg:
msg.future.set_result(salts)