diff --git a/telethon/extensions/tcp_client.py b/telethon/extensions/tcp_client.py index 850b515d..4a9d8827 100644 --- a/telethon/extensions/tcp_client.py +++ b/telethon/extensions/tcp_client.py @@ -8,22 +8,37 @@ This class is also not concerned about disconnections or retries of any sort, nor any other kind of errors such as connecting twice. """ import asyncio +import errno import logging import socket +from datetime import timedelta from io import BytesIO +CONN_RESET_ERRNOS = { + errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, + errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH, + errno.ECONNREFUSED, errno.ECONNRESET, errno.ECONNABORTED, + errno.ENETDOWN, errno.ENETRESET, errno.ECONNABORTED, + errno.EHOSTDOWN, errno.EPIPE, errno.ESHUTDOWN +} +# catched: EHOSTUNREACH, ECONNREFUSED, ECONNRESET, ENETUNREACH +# ConnectionError: EPIPE, ESHUTDOWN, ECONNABORTED, ECONNREFUSED, ECONNRESET + try: import socks except ImportError: socks = None - __log__ = logging.getLogger(__name__) class TcpClient: """A simple TCP client to ease the work with sockets and proxies.""" - def __init__(self, proxy=None, timeout=5): + + class SocketClosed(ConnectionError): + pass + + def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None): """ Initializes the TCP client. @@ -32,7 +47,9 @@ class TcpClient: """ self.proxy = proxy self._socket = None - self._loop = asyncio.get_event_loop() + self._loop = loop or asyncio.get_event_loop() + self._closed = asyncio.Event(loop=self._loop) + self._closed.set() if isinstance(timeout, (int, float)): self.timeout = float(timeout) @@ -57,7 +74,7 @@ class TcpClient: async def connect(self, ip, port): """ - Tries connecting to IP:port. + Tries connecting to IP:port unless an OSError is raised. :param ip: the IP to connect to. :param port: the port to connect to. @@ -68,42 +85,78 @@ class TcpClient: else: mode, address = socket.AF_INET, (ip, port) - if self._socket is None: - self._socket = self._create_socket(mode, self.proxy) + try: + if self._socket is None: + self._socket = self._create_socket(mode, self.proxy) - await asyncio.wait_for(self._loop.sock_connect(self._socket, address), - self.timeout, loop=self._loop) + await asyncio.wait_for( + self._loop.sock_connect(self._socket, address), + timeout=self.timeout, + loop=self._loop + ) + self._closed.clear() + except asyncio.TimeoutError as e: + raise TimeoutError() from e + except OSError as e: + if e.errno in CONN_RESET_ERRNOS: + raise ConnectionResetError() from e + else: + raise @property def is_connected(self): """Determines whether the client is connected or not.""" - # TODO fileno() is >= 0 even before calling sock_connect! - return self._socket is not None and self._socket.fileno() >= 0 + return not self._closed.is_set() def close(self): """Closes the connection.""" - if self._socket is not None: - try: + try: + if self._socket is not None: + if self.is_connected: + self._socket.shutdown(socket.SHUT_RDWR) self._socket.close() - except OSError: - pass - finally: - self._socket = None + except OSError: + pass # Ignore ENOTCONN, EBADF, and any other error when closing + finally: + self._socket = None + self._closed.set() + + async def _wait_timeout_or_close(self, coro): + """ + Waits for the given coroutine to complete unless + the socket is closed or `self.timeout` expires. + """ + done, running = await asyncio.wait( + [coro, self._closed.wait()], + timeout=self.timeout, + return_when=asyncio.FIRST_COMPLETED, + loop=self._loop + ) + for r in running: + r.cancel() + if not self.is_connected: + raise self.SocketClosed() + if not done: + raise TimeoutError() + return done.pop().result() async def write(self, data): """ Writes (sends) the specified bytes to the connected peer. - :param data: the data to send. """ if not self.is_connected: - raise ConnectionError() - - await asyncio.wait_for( - self.sock_sendall(data), - timeout=self.timeout, - loop=self._loop - ) + raise ConnectionResetError('Not connected') + try: + await self._wait_timeout_or_close(self.sock_sendall(data)) + except self.SocketClosed: + raise ConnectionResetError('Socket has closed') + except OSError as e: + __log__.info('OSError "%s" while writing data', e) + if e.errno in CONN_RESET_ERRNOS: + raise ConnectionResetError() from e + else: + raise async def read(self, size): """ @@ -113,16 +166,32 @@ class TcpClient: :return: the read data with len(data) == size. """ if not self.is_connected: - raise ConnectionError() + raise ConnectionResetError('Not connected') with BytesIO() as buffer: bytes_left = size while bytes_left != 0: - partial = await asyncio.wait_for( - self.sock_recv(bytes_left), - timeout=self.timeout, - loop=self._loop - ) + try: + partial = await self._wait_timeout_or_close( + self.sock_recv(bytes_left) + ) + except TimeoutError as e: + if bytes_left < size: + __log__.warning( + 'socket timeout "%s" when %d/%d had been received', + e, size - bytes_left, size + ) + raise + except self.SocketClosed: + raise ConnectionResetError( + 'Socket has closed while reading data' + ) + except OSError as e: + if e.errno in CONN_RESET_ERRNOS: + raise ConnectionResetError() from e + else: + raise + if not partial: raise ConnectionResetError() @@ -141,7 +210,7 @@ class TcpClient: def _sock_recv(self, fut, registered_fd, n): if registered_fd is not None: self._loop.remove_reader(registered_fd) - if fut.cancelled(): + if fut.cancelled() or self._socket is None: return try: @@ -165,7 +234,7 @@ class TcpClient: def _sock_sendall(self, fut, registered_fd, data): if registered_fd: self._loop.remove_writer(registered_fd) - if fut.cancelled(): + if fut.cancelled() or self._socket is None: return try: