From e469258ab9f986d427702fe5310626ad4033a038 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 6 Jun 2018 20:41:01 +0200 Subject: [PATCH] Create a new MTProtoSender structure and its foundation This means that the TcpClient and the Connection (currently only ConnectionTcpFull) will no longer be concerned about handling errors, but the MTProtoSender will. The foundation of the library will now be based on asyncio. --- telethon/extensions/tcp_client.py | 235 ++++++++++++------------- telethon/helpers.py | 21 ++- telethon/network/connection/common.py | 17 +- telethon/network/connection/tcpfull.py | 8 +- telethon/network/mtprotosender.py | 144 +++++++++++++++ telethon/tl/tlobject.py | 2 +- 6 files changed, 284 insertions(+), 143 deletions(-) create mode 100644 telethon/network/mtprotosender.py diff --git a/telethon/extensions/tcp_client.py b/telethon/extensions/tcp_client.py index 8a5800fb..a12d33a9 100644 --- a/telethon/extensions/tcp_client.py +++ b/telethon/extensions/tcp_client.py @@ -1,31 +1,31 @@ """ This module holds a rough implementation of the C# TCP client. + +This class is **not** safe across several tasks since partial reads +may be ``await``'ed before being able to return the exact byte count. + +This class is also not concerned about disconnections or retries of +any sort, nor any other kind of errors such as connecting twice. """ -import errno +import asyncio import logging import socket -import time -from datetime import timedelta -from io import BytesIO, BufferedWriter -from threading import Lock +from io import BytesIO try: import socks except ImportError: socks = None -MAX_TIMEOUT = 15 # in seconds -CONN_RESET_ERRNOS = { - errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, - errno.EINVAL, errno.ENOTCONN -} __log__ = logging.getLogger(__name__) +# TODO Except asyncio.TimeoutError, ConnectionError, OSError... +# ...for connect, write and read (in the upper levels, not here) class TcpClient: """A simple TCP client to ease the work with sockets and proxies.""" - def __init__(self, proxy=None, timeout=timedelta(seconds=5)): + def __init__(self, proxy=None, timeout=5): """ Initializes the TCP client. @@ -34,31 +34,32 @@ class TcpClient: """ self.proxy = proxy self._socket = None - self._closing_lock = Lock() + self._loop = asyncio.get_event_loop() - if isinstance(timeout, timedelta): - self.timeout = timeout.seconds - elif isinstance(timeout, (int, float)): + if isinstance(timeout, (int, float)): self.timeout = float(timeout) + elif hasattr(timeout, 'seconds'): + self.timeout = float(timeout.seconds) else: raise TypeError('Invalid timeout type: {}'.format(type(timeout))) - def _recreate_socket(self, mode): - if self.proxy is None: - self._socket = socket.socket(mode, socket.SOCK_STREAM) + @staticmethod + def _create_socket(mode, proxy): + if proxy is None: + s = socket.socket(mode, socket.SOCK_STREAM) else: import socks - self._socket = socks.socksocket(mode, socket.SOCK_STREAM) - if type(self.proxy) is dict: - self._socket.set_proxy(**self.proxy) + s = socks.socksocket(mode, socket.SOCK_STREAM) + if isinstance(proxy, dict): + s.set_proxy(**proxy) else: # tuple, list, etc. - self._socket.set_proxy(*self.proxy) + s.set_proxy(*proxy) + s.setblocking(False) + return s - self._socket.settimeout(self.timeout) - - def connect(self, ip, port): + async def connect(self, ip, port): """ - Tries connecting forever to IP:port unless an OSError is raised. + Tries connecting to IP:port. :param ip: the IP to connect to. :param port: the port to connect to. @@ -69,136 +70,116 @@ class TcpClient: else: mode, address = socket.AF_INET, (ip, port) - timeout = 1 - while True: - try: - while not self._socket: - self._recreate_socket(mode) + if self._socket is None: + self._socket = self._create_socket(mode, self.proxy) - self._socket.connect(address) - break # Successful connection, stop retrying to connect - except OSError as e: - __log__.info('OSError "%s" raised while connecting', e) - # Stop retrying to connect if proxy connection error occurred - if socks and isinstance(e, socks.ProxyConnectionError): - raise - # There are some errors that we know how to handle, and - # the loop will allow us to retry - if e.errno in (errno.EBADF, errno.ENOTSOCK, errno.EINVAL, - errno.ECONNREFUSED, # Windows-specific follow - getattr(errno, 'WSAEACCES', None)): - # Bad file descriptor, i.e. socket was closed, set it - # to none to recreate it on the next iteration - self._socket = None - time.sleep(timeout) - timeout *= 2 - if timeout > MAX_TIMEOUT: - raise - else: - raise + asyncio.wait_for(self._loop.sock_connect(self._socket, address), + self.timeout, loop=self._loop) - def _get_connected(self): + @property + def is_connected(self): """Determines whether the client is connected or not.""" return self._socket is not None and self._socket.fileno() >= 0 - connected = property(fget=_get_connected) - def close(self): """Closes the connection.""" - if self._closing_lock.locked(): - # Already closing, no need to close again (avoid None.close()) - return - - with self._closing_lock: + if self._socket is not None: try: - if self._socket is not None: - self._socket.shutdown(socket.SHUT_RDWR) - self._socket.close() - except OSError: - pass # Ignore ENOTCONN, EBADF, and any other error when closing + self._socket.shutdown(socket.SHUT_RDWR) + self._socket.close() finally: self._socket = None - def write(self, data): + async def write(self, data): """ Writes (sends) the specified bytes to the connected peer. :param data: the data to send. """ - if self._socket is None: - self._raise_connection_reset(None) + if not self.is_connected: + raise ConnectionError() - # TODO Timeout may be an issue when sending the data, Changed in v3.5: - # The socket timeout is now the maximum total duration to send all data. - try: - self._socket.sendall(data) - except socket.timeout as e: - __log__.debug('socket.timeout "%s" while writing data', e) - raise TimeoutError() from e - except ConnectionError as e: - __log__.info('ConnectionError "%s" while writing data', e) - self._raise_connection_reset(e) - except OSError as e: - __log__.info('OSError "%s" while writing data', e) - if e.errno in CONN_RESET_ERRNOS: - self._raise_connection_reset(e) - else: - raise + await asyncio.wait_for( + self.sock_sendall(data), + timeout=self.timeout, + loop=self._loop + ) - def read(self, size): + async def read(self, size): """ Reads (receives) a whole block of size bytes from the connected peer. :param size: the size of the block to be read. :return: the read data with len(data) == size. """ - if self._socket is None: - self._raise_connection_reset(None) + if not self.is_connected: + raise ConnectionError() - with BufferedWriter(BytesIO(), buffer_size=size) as buffer: + with BytesIO() as buffer: bytes_left = size while bytes_left != 0: - try: - partial = self._socket.recv(bytes_left) - except socket.timeout as e: - # These are somewhat common if the server has nothing - # to send to us, so use a lower logging priority. - if bytes_left < size: - __log__.warning( - 'socket.timeout "%s" when %d/%d had been received', - e, size - bytes_left, size - ) - else: - __log__.debug( - 'socket.timeout "%s" while reading data', e - ) - - raise TimeoutError() from e - except ConnectionError as e: - __log__.info('ConnectionError "%s" while reading data', e) - self._raise_connection_reset(e) - except OSError as e: - if e.errno != errno.EBADF and self._closing_lock.locked(): - # Ignore bad file descriptor while closing - __log__.info('OSError "%s" while reading data', e) - - if e.errno in CONN_RESET_ERRNOS: - self._raise_connection_reset(e) - else: - raise - - if len(partial) == 0: - self._raise_connection_reset(None) + partial = await asyncio.wait_for( + self.sock_recv(bytes_left), + timeout=self.timeout, + loop=self._loop + ) + if not partial == 0: + raise ConnectionResetError() buffer.write(partial) bytes_left -= len(partial) - # If everything went fine, return the read bytes - buffer.flush() - return buffer.raw.getvalue() + return buffer.getvalue() - def _raise_connection_reset(self, original): - """Disconnects the client and raises ConnectionResetError.""" - self.close() # Connection reset -> flag as socket closed - raise ConnectionResetError('The server has closed the connection.')\ - from original + # Due to recent https://github.com/python/cpython/pull/4386 + # Credit to @andr-04 for his original implementation + def sock_recv(self, n): + fut = self._loop.create_future() + self._sock_recv(fut, None, n) + return fut + + def _sock_recv(self, fut, registered_fd, n): + if registered_fd is not None: + self._loop.remove_reader(registered_fd) + if fut.cancelled(): + return + + try: + data = self._socket.recv(n) + except (BlockingIOError, InterruptedError): + fd = self._socket.fileno() + self._loop.add_reader(fd, self._sock_recv, fut, fd, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, data): + fut = self._loop.create_future() + if data: + self._sock_sendall(fut, None, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered_fd, data): + if registered_fd: + self._loop.remove_writer(registered_fd) + if fut.cancelled(): + return + + try: + n = self._socket.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + fd = self._socket.fileno() + self._loop.add_writer(fd, self._sock_sendall, fut, fd, data) diff --git a/telethon/helpers.py b/telethon/helpers.py index 9ca91e4f..c76f93ec 100644 --- a/telethon/helpers.py +++ b/telethon/helpers.py @@ -3,9 +3,9 @@ import os import struct from hashlib import sha1, sha256 -from telethon.crypto import AES -from telethon.errors import SecurityError -from telethon.extensions import BinaryReader +from .crypto import AES +from .errors import SecurityError, BrokenAuthKeyError +from .extensions import BinaryReader # region Multiple utilities @@ -46,15 +46,22 @@ def pack_message(session, message): return key_id + msg_key + AES.encrypt_ige(data + padding, aes_key, aes_iv) -def unpack_message(session, reader): +def unpack_message(session, body): """Unpacks a message following MtProto 2.0 guidelines""" # See https://core.telegram.org/mtproto/description - if reader.read_long(signed=False) != session.auth_key.key_id: + if len(body) < 8: + if body == b'l\xfe\xff\xff': + raise BrokenAuthKeyError() + else: + raise BufferError("Can't decode packet ({})".format(body)) + + key_id = struct.unpack('