diff --git a/telethon/extensions/tcp_client.py b/telethon/extensions/tcp_client.py index 134c36ae..672c0a0f 100644 --- a/telethon/extensions/tcp_client.py +++ b/telethon/extensions/tcp_client.py @@ -1,9 +1,12 @@ # Python rough implementation of a C# TCP client +import asyncio import errno import socket from datetime import timedelta from io import BytesIO, BufferedWriter +loop = asyncio.get_event_loop() + class TcpClient: def __init__(self, proxy=None, timeout=timedelta(seconds=5)): @@ -30,7 +33,7 @@ class TcpClient: self._socket.settimeout(self.timeout) - def connect(self, ip, port): + async def connect(self, ip, port): """Connects to the specified IP and port number. 'timeout' must be given in seconds """ @@ -44,7 +47,7 @@ class TcpClient: while not self._socket: self._recreate_socket(mode) - self._socket.connect(address) + await loop.sock_connect(self._socket, address) break # Successful connection, stop retrying to connect except OSError as e: # There are some errors that we know how to handle, and @@ -72,15 +75,13 @@ class TcpClient: finally: self._socket = None - def write(self, data): + async def write(self, data): """Writes (sends) the specified bytes to the connected peer""" if self._socket is None: raise ConnectionResetError() - # TODO Timeout may be an issue when sending the data, Changed in v3.5: - # The socket timeout is now the maximum total duration to send all data. try: - self._socket.sendall(data) + await loop.sock_sendall(self._socket, data) except socket.timeout as e: raise TimeoutError() from e except BrokenPipeError: @@ -91,14 +92,9 @@ class TcpClient: else: raise - def read(self, size): + async def read(self, size): """Reads (receives) a whole block of 'size bytes from the connected peer. - - A timeout can be specified, which will cancel the operation if - no data has been read in the specified time. If data was read - and it's waiting for more, the timeout will NOT cancel the - operation. Set to None for no timeout """ if self._socket is None: raise ConnectionResetError() @@ -108,7 +104,7 @@ class TcpClient: bytes_left = size while bytes_left != 0: try: - partial = self._socket.recv(bytes_left) + partial = await loop.sock_recv(self._socket, bytes_left) except socket.timeout as e: raise TimeoutError() from e except OSError as e: diff --git a/telethon/network/authenticator.py b/telethon/network/authenticator.py index 1081897a..f46f5430 100644 --- a/telethon/network/authenticator.py +++ b/telethon/network/authenticator.py @@ -17,21 +17,21 @@ from ..tl.functions import ( ) -def do_authentication(connection, retries=5): +async def do_authentication(connection, retries=5): if not retries or retries < 0: retries = 1 last_error = None while retries: try: - return _do_authentication(connection) + return await _do_authentication(connection) except (SecurityError, AssertionError, NotImplementedError) as e: last_error = e retries -= 1 raise last_error -def _do_authentication(connection): +async def _do_authentication(connection): """Executes the authentication process with the Telegram servers. If no error is raised, returns both the authorization key and the time offset. @@ -42,8 +42,8 @@ def _do_authentication(connection): req_pq_request = ReqPqRequest( nonce=int.from_bytes(os.urandom(16), 'big', signed=True) ) - sender.send(req_pq_request.to_bytes()) - with BinaryReader(sender.receive()) as reader: + await sender.send(req_pq_request.to_bytes()) + with BinaryReader(await sender.receive()) as reader: req_pq_request.on_response(reader) res_pq = req_pq_request.result @@ -90,10 +90,10 @@ def _do_authentication(connection): public_key_fingerprint=target_fingerprint, encrypted_data=cipher_text ) - sender.send(req_dh_params.to_bytes()) + await sender.send(req_dh_params.to_bytes()) # Step 2 response: DH Exchange - with BinaryReader(sender.receive()) as reader: + with BinaryReader(await sender.receive()) as reader: req_dh_params.on_response(reader) server_dh_params = req_dh_params.result @@ -157,10 +157,10 @@ def _do_authentication(connection): server_nonce=res_pq.server_nonce, encrypted_data=client_dh_encrypted, ) - sender.send(set_client_dh.to_bytes()) + await sender.send(set_client_dh.to_bytes()) # Step 3 response: Complete DH Exchange - with BinaryReader(sender.receive()) as reader: + with BinaryReader(await sender.receive()) as reader: set_client_dh.on_response(reader) dh_gen = set_client_dh.result diff --git a/telethon/network/connection.py b/telethon/network/connection.py index 2500c0c1..77a3c87b 100644 --- a/telethon/network/connection.py +++ b/telethon/network/connection.py @@ -1,14 +1,13 @@ +import errno import os import struct from datetime import timedelta -from zlib import crc32 from enum import Enum - -import errno +from zlib import crc32 from ..crypto import AESModeCTR -from ..extensions import TcpClient from ..errors import InvalidChecksumError +from ..extensions import TcpClient class ConnectionMode(Enum): @@ -74,9 +73,9 @@ class Connection: setattr(self, 'write', self._write_plain) setattr(self, 'read', self._read_plain) - def connect(self, ip, port): + async def connect(self, ip, port): try: - self.conn.connect(ip, port) + await self.conn.connect(ip, port) except OSError as e: if e.errno == errno.EISCONN: return # Already connected, no need to re-set everything up @@ -85,16 +84,16 @@ class Connection: self._send_counter = 0 if self._mode == ConnectionMode.TCP_ABRIDGED: - self.conn.write(b'\xef') + await self.conn.write(b'\xef') elif self._mode == ConnectionMode.TCP_INTERMEDIATE: - self.conn.write(b'\xee\xee\xee\xee') + await self.conn.write(b'\xee\xee\xee\xee') elif self._mode == ConnectionMode.TCP_OBFUSCATED: - self._setup_obfuscation() + await self._setup_obfuscation() def get_timeout(self): return self.conn.timeout - def _setup_obfuscation(self): + async def _setup_obfuscation(self): # Obfuscated messages secrets cannot start with any of these keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4) while True: @@ -119,7 +118,7 @@ class Connection: self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv) random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64] - self.conn.write(bytes(random)) + await self.conn.write(bytes(random)) def is_connected(self): return self.conn.connected @@ -135,20 +134,23 @@ class Connection: # region Receive message implementations - def recv(self): + async def recv(self): """Receives and unpacks a message""" # Default implementation is just an error raise ValueError('Invalid connection mode specified: ' + str(self._mode)) - def _recv_tcp_full(self): - packet_length_bytes = self.read(4) + async def _recv_tcp_full(self): + # TODO We don't want another call to this method that could + # potentially await on another self.read(n). Is this guaranteed + # by asyncio? + packet_length_bytes = await self.read(4) packet_length = int.from_bytes(packet_length_bytes, 'little') - seq_bytes = self.read(4) + seq_bytes = await self.read(4) seq = int.from_bytes(seq_bytes, 'little') - body = self.read(packet_length - 12) - checksum = int.from_bytes(self.read(4), 'little') + body = await self.read(packet_length - 12) + checksum = int.from_bytes(await self.read(4), 'little') valid_checksum = crc32(packet_length_bytes + seq_bytes + body) if checksum != valid_checksum: @@ -156,72 +158,70 @@ class Connection: return body - def _recv_intermediate(self): - return self.read(int.from_bytes(self.read(4), 'little')) + async def _recv_intermediate(self): + return await self.read(int.from_bytes(self.read(4), 'little')) - def _recv_abridged(self): + async def _recv_abridged(self): length = int.from_bytes(self.read(1), 'little') if length >= 127: length = int.from_bytes(self.read(3) + b'\0', 'little') - return self.read(length << 2) + return await self.read(length << 2) # endregion # region Send message implementations - def send(self, message): + async def send(self, message): """Encapsulates and sends the given message""" # Default implementation is just an error raise ValueError('Invalid connection mode specified: ' + str(self._mode)) - def _send_tcp_full(self, message): + async def _send_tcp_full(self, message): # https://core.telegram.org/mtproto#tcp-transport # total length, sequence number, packet and checksum (CRC32) length = len(message) + 12 data = struct.pack('> 2 if length < 127: length = struct.pack('B', length) else: length = b'\x7f' + int.to_bytes(length, 3, 'little') - self.write(length + message) + await self.write(length + message) # endregion # region Read implementations - def read(self, length): + async def read(self, length): raise ValueError('Invalid connection mode specified: ' + str(self._mode)) - def _read_plain(self, length): - return self.conn.read(length) + async def _read_plain(self, length): + return await self.conn.read(length) - def _read_obfuscated(self, length): - return self._aes_decrypt.encrypt( - self.conn.read(length) - ) + async def _read_obfuscated(self, length): + return await self._aes_decrypt.encrypt(self.conn.read(length)) # endregion # region Write implementations - def write(self, data): + async def write(self, data): raise ValueError('Invalid connection mode specified: ' + str(self._mode)) - def _write_plain(self, data): - self.conn.write(data) + async def _write_plain(self, data): + await self.conn.write(data) - def _write_obfuscated(self, data): - self.conn.write(self._aes_encrypt.encrypt(data)) + async def _write_obfuscated(self, data): + await self.conn.write(self._aes_encrypt.encrypt(data)) # endregion diff --git a/telethon/network/mtproto_plain_sender.py b/telethon/network/mtproto_plain_sender.py index c7c021be..9089a72d 100644 --- a/telethon/network/mtproto_plain_sender.py +++ b/telethon/network/mtproto_plain_sender.py @@ -16,23 +16,23 @@ class MtProtoPlainSender: self._last_msg_id = 0 self._connection = connection - def connect(self): - self._connection.connect() + async def connect(self): + await self._connection.connect() def disconnect(self): self._connection.close() - def send(self, data): + async def send(self, data): """Sends a plain packet (auth_key_id = 0) containing the given message body (data) """ - self._connection.send( + await self._connection.send( struct.pack('