Create and use UpdateState to .process() unhandled TLObjects

This commit is contained in:
Lonami Exo
2017-09-07 18:49:08 +02:00
parent 49e884b005
commit d4f36162cd
4 changed files with 79 additions and 41 deletions

View File

@@ -35,16 +35,6 @@ class MtProtoSender:
# TODO There might be a better way to handle msgs_ack requests
self.logging_out = False
# Every unhandled result gets passed to these callbacks, which
# should be functions accepting a single parameter: a TLObject.
# This should only be Update(s), although it can actually be any type.
#
# The thread from which these callbacks are called can be any.
#
# The creator of the MtProtoSender is responsible for setting this
# to point to the list wherever their callbacks reside.
self.unhandled_callbacks = None
def connect(self):
"""Connects to the server"""
self.connection.connect()
@@ -90,12 +80,15 @@ class MtProtoSender:
del self._need_confirmation[:]
def receive(self):
def receive(self, update_state):
"""Receives a single message from the connected endpoint.
This method returns nothing, and will only affect other parts
of the MtProtoSender such as the updates callback being fired
or a pending request being confirmed.
Any unhandled object (likely updates) will be passed to
update_state.process(TLObject).
"""
# TODO Don't ignore updates
self._logger.debug('Receiving a message...')
@@ -103,8 +96,7 @@ class MtProtoSender:
message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader:
self._process_msg(
remote_msg_id, remote_seq, reader, updates=None)
self._process_msg(remote_msg_id, remote_seq, reader, update_state)
self._logger.debug('Received message.')
@@ -172,7 +164,7 @@ class MtProtoSender:
return message, remote_msg_id, remote_sequence
def _process_msg(self, msg_id, sequence, reader, updates):
def _process_msg(self, msg_id, sequence, reader, state):
"""Processes and handles a Telegram message.
Returns True if the message was handled correctly and doesn't
@@ -193,10 +185,10 @@ class MtProtoSender:
return self._handle_pong(msg_id, sequence, reader)
if code == 0x73f1f8dc: # msg_container
return self._handle_container(msg_id, sequence, reader, updates)
return self._handle_container(msg_id, sequence, reader, state)
if code == 0x3072cfa1: # gzip_packed
return self._handle_gzip_packed(msg_id, sequence, reader, updates)
return self._handle_gzip_packed(msg_id, sequence, reader, state)
if code == 0xedab447b: # bad_server_salt
return self._handle_bad_server_salt(msg_id, sequence, reader)
@@ -221,16 +213,15 @@ class MtProtoSender:
# If the code is not parsed manually then it should be a TLObject.
if code in tlobjects:
result = reader.tgread_object()
if self.unhandled_callbacks:
self._logger.debug(
'Passing TLObject to callbacks %s', repr(result)
)
for callback in self.unhandled_callbacks:
callback(result)
else:
if state is None:
self._logger.debug(
'Ignoring unhandled TLObject %s', repr(result)
)
else:
self._logger.debug(
'Processing TLObject %s', repr(result)
)
state.process(result)
return True
@@ -261,7 +252,7 @@ class MtProtoSender:
return True
def _handle_container(self, msg_id, sequence, reader, updates):
def _handle_container(self, msg_id, sequence, reader, state):
self._logger.debug('Handling container')
reader.read_int(signed=False) # code
size = reader.read_int()
@@ -274,8 +265,7 @@ class MtProtoSender:
# Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session)
try:
if not self._process_msg(
inner_msg_id, sequence, reader, updates):
if not self._process_msg(inner_msg_id, sequence, reader, state):
reader.set_position(begin_position + inner_length)
except:
# If any error is raised, something went wrong; skip the packet
@@ -366,14 +356,13 @@ class MtProtoSender:
self._logger.debug('Lost request will be skipped.')
return False
def _handle_gzip_packed(self, msg_id, sequence, reader, updates):
def _handle_gzip_packed(self, msg_id, sequence, reader, state):
self._logger.debug('Handling gzip packed data')
reader.read_int(signed=False) # code
packed_data = reader.tgread_bytes()
unpacked_data = gzip.decompress(packed_data)
with BinaryReader(unpacked_data) as compressed_reader:
return self._process_msg(
msg_id, sequence, compressed_reader, updates)
return self._process_msg(msg_id, sequence, compressed_reader, state)
# endregion