diff --git a/main.py b/main.py index 4cc390ec..0a18496b 100755 --- a/main.py +++ b/main.py @@ -74,8 +74,8 @@ if __name__ == '__main__': date = datetime.fromtimestamp(msg.date) print('[{}:{}] {}: {}'.format(date.hour, date.minute, name, msg.message)) - # Send chat message - else: + # Send chat message (if any) + elif msg: client.send_message(input_peer, msg, markdown=True, no_web_page=True) print('Thanks for trying the interactive example! Exiting...') diff --git a/network/mtproto_sender.py b/network/mtproto_sender.py index 5d42a537..cf9f663d 100755 --- a/network/mtproto_sender.py +++ b/network/mtproto_sender.py @@ -3,7 +3,7 @@ import gzip from errors import * from time import sleep -from threading import Thread +from threading import Thread, Lock import utils from crypto import AES @@ -14,22 +14,20 @@ from tl.all_tlobjects import tlobjects class MtProtoSender: """MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)""" - def __init__(self, transport, session, check_updates_delay=0.1): - """If check_updates_delay is None, no updates will be checked. - Otherwise, specifies every how often updates should be checked""" - + def __init__(self, transport, session, check_updates=True): self.transport = transport self.session = session self.need_confirmation = [] # Message IDs that need confirmation self.on_update_handlers = [] - # Set up updates thread, if the delay is not None - self.check_updates_delay = check_updates_delay - if check_updates_delay: + # Store a Lock instance to make this class safely multi-threaded + self.lock = Lock() + + if check_updates: self.updates_thread = Thread(target=self.updates_thread_method, name='Updates thread') self.updates_thread_running = True - self.updates_thread_paused = True + self.updates_thread_receiving = False self.updates_thread.start() @@ -60,14 +58,23 @@ class MtProtoSender: """Sends the specified MTProtoRequest, previously sending any message which needed confirmation. This also pauses the updates thread""" - # Pause the updates thread: we cannot use self.transport at the same time! - self.pause_updates_thread() + # Only cancel the receive *if* it was the + # updates thread who was receiving. We do + # not want to cancel other pending requests! + if self.updates_thread_receiving: + self.transport.cancel_receive() + + # Now only us can be using this method + self.lock.acquire() # If any message needs confirmation send an AckRequest first if self.need_confirmation: - msgs_ack = MsgsAck(self.need_confirmation[:]) + msgs_ack = MsgsAck(self.need_confirmation) + with BinaryWriter() as writer: + msgs_ack.on_send(writer) + self.send_packet(writer.get_bytes(), msgs_ack) + del self.need_confirmation[:] - self.send(msgs_ack) # Finally send our packed request with BinaryWriter() as writer: @@ -83,23 +90,27 @@ class MtProtoSender: """Receives the specified MTProtoRequest ("fills in it" the received data). This also restores the updates thread""" - # Don't stop receiving until we get the request we wanted - while not request.confirm_received: - seq, body = self.transport.receive() - message, remote_msg_id, remote_sequence = self.decode_msg(body) + try: + # Don't stop trying to receive until we get the request we wanted + while not request.confirm_received: + seq, body = self.transport.receive() + message, remote_msg_id, remote_sequence = self.decode_msg(body) - with BinaryReader(message) as reader: - self.process_msg(remote_msg_id, remote_sequence, reader, request) + with BinaryReader(message) as reader: + self.process_msg(remote_msg_id, remote_sequence, reader, request) - # Once we have our request, restore the updates thread - self.restore_updates_thread() + finally: + # Once we are done trying to get our request, + # restore the updates thread and release the lock + self.lock.release() # endregion # region Low level processing def send_packet(self, packet, request): - """Sends the given packet bytes with the additional information of the original request""" + """Sends the given packet bytes with the additional + information of the original request. This does NOT lock the threads!""" request.msg_id = self.session.get_new_msg_id() # First calculate plain_text to encrypt it @@ -276,22 +287,12 @@ class MtProtoSender: # endregion - def pause_updates_thread(self): - """Pauses the updates thread and sleeps until it's safe to continue""" - if not self.updates_thread_paused: - self.updates_thread_paused = True - self.transport.cancel_receive() - - def restore_updates_thread(self): - """Restores the updates thread""" - self.updates_thread_paused = False - - # TODO avoid, if possible using sleeps; Use thread locks instead def updates_thread_method(self): """This method will run until specified and listen for incoming updates""" while self.updates_thread_running: - if not self.updates_thread_paused: + with self.lock: try: + self.updates_thread_receiving = True seq, body = self.transport.receive() message, remote_msg_id, remote_sequence = self.decode_msg(body) @@ -301,4 +302,10 @@ class MtProtoSender: except ReadCancelledError: pass - sleep(self.transport.get_client_delay()) + self.updates_thread_receiving = False + + # If we are here, it is because the read was cancelled + # Sleep a bit just to give enough time for the other thread + # to acquire the lock. No need to sleep if we're not running anymore + if self.updates_thread_running: + sleep(0.1)