Apply @andr-04 asyncio commits to TcpClient

This commit is contained in:
Lonami Exo 2018-06-14 16:08:23 +02:00
parent 3ce8b17193
commit c9ea1bafc0

View File

@ -8,22 +8,37 @@ This class is also not concerned about disconnections or retries of
any sort, nor any other kind of errors such as connecting twice. any sort, nor any other kind of errors such as connecting twice.
""" """
import asyncio import asyncio
import errno
import logging import logging
import socket import socket
from datetime import timedelta
from io import BytesIO from io import BytesIO
CONN_RESET_ERRNOS = {
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
errno.ECONNREFUSED, errno.ECONNRESET, errno.ECONNABORTED,
errno.ENETDOWN, errno.ENETRESET, errno.ECONNABORTED,
errno.EHOSTDOWN, errno.EPIPE, errno.ESHUTDOWN
}
# catched: EHOSTUNREACH, ECONNREFUSED, ECONNRESET, ENETUNREACH
# ConnectionError: EPIPE, ESHUTDOWN, ECONNABORTED, ECONNREFUSED, ECONNRESET
try: try:
import socks import socks
except ImportError: except ImportError:
socks = None socks = None
__log__ = logging.getLogger(__name__) __log__ = logging.getLogger(__name__)
class TcpClient: class TcpClient:
"""A simple TCP client to ease the work with sockets and proxies.""" """A simple TCP client to ease the work with sockets and proxies."""
def __init__(self, proxy=None, timeout=5):
class SocketClosed(ConnectionError):
pass
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
""" """
Initializes the TCP client. Initializes the TCP client.
@ -32,7 +47,9 @@ class TcpClient:
""" """
self.proxy = proxy self.proxy = proxy
self._socket = None self._socket = None
self._loop = asyncio.get_event_loop() self._loop = loop or asyncio.get_event_loop()
self._closed = asyncio.Event(loop=self._loop)
self._closed.set()
if isinstance(timeout, (int, float)): if isinstance(timeout, (int, float)):
self.timeout = float(timeout) self.timeout = float(timeout)
@ -57,7 +74,7 @@ class TcpClient:
async def connect(self, ip, port): async def connect(self, ip, port):
""" """
Tries connecting to IP:port. Tries connecting to IP:port unless an OSError is raised.
:param ip: the IP to connect to. :param ip: the IP to connect to.
:param port: the port to connect to. :param port: the port to connect to.
@ -68,42 +85,78 @@ class TcpClient:
else: else:
mode, address = socket.AF_INET, (ip, port) mode, address = socket.AF_INET, (ip, port)
if self._socket is None: try:
self._socket = self._create_socket(mode, self.proxy) if self._socket is None:
self._socket = self._create_socket(mode, self.proxy)
await asyncio.wait_for(self._loop.sock_connect(self._socket, address), await asyncio.wait_for(
self.timeout, loop=self._loop) self._loop.sock_connect(self._socket, address),
timeout=self.timeout,
loop=self._loop
)
self._closed.clear()
except asyncio.TimeoutError as e:
raise TimeoutError() from e
except OSError as e:
if e.errno in CONN_RESET_ERRNOS:
raise ConnectionResetError() from e
else:
raise
@property @property
def is_connected(self): def is_connected(self):
"""Determines whether the client is connected or not.""" """Determines whether the client is connected or not."""
# TODO fileno() is >= 0 even before calling sock_connect! return not self._closed.is_set()
return self._socket is not None and self._socket.fileno() >= 0
def close(self): def close(self):
"""Closes the connection.""" """Closes the connection."""
if self._socket is not None: try:
try: if self._socket is not None:
if self.is_connected:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close() self._socket.close()
except OSError: except OSError:
pass pass # Ignore ENOTCONN, EBADF, and any other error when closing
finally: finally:
self._socket = None self._socket = None
self._closed.set()
async def _wait_timeout_or_close(self, coro):
"""
Waits for the given coroutine to complete unless
the socket is closed or `self.timeout` expires.
"""
done, running = await asyncio.wait(
[coro, self._closed.wait()],
timeout=self.timeout,
return_when=asyncio.FIRST_COMPLETED,
loop=self._loop
)
for r in running:
r.cancel()
if not self.is_connected:
raise self.SocketClosed()
if not done:
raise TimeoutError()
return done.pop().result()
async def write(self, data): async def write(self, data):
""" """
Writes (sends) the specified bytes to the connected peer. Writes (sends) the specified bytes to the connected peer.
:param data: the data to send. :param data: the data to send.
""" """
if not self.is_connected: if not self.is_connected:
raise ConnectionError() raise ConnectionResetError('Not connected')
try:
await asyncio.wait_for( await self._wait_timeout_or_close(self.sock_sendall(data))
self.sock_sendall(data), except self.SocketClosed:
timeout=self.timeout, raise ConnectionResetError('Socket has closed')
loop=self._loop except OSError as e:
) __log__.info('OSError "%s" while writing data', e)
if e.errno in CONN_RESET_ERRNOS:
raise ConnectionResetError() from e
else:
raise
async def read(self, size): async def read(self, size):
""" """
@ -113,16 +166,32 @@ class TcpClient:
:return: the read data with len(data) == size. :return: the read data with len(data) == size.
""" """
if not self.is_connected: if not self.is_connected:
raise ConnectionError() raise ConnectionResetError('Not connected')
with BytesIO() as buffer: with BytesIO() as buffer:
bytes_left = size bytes_left = size
while bytes_left != 0: while bytes_left != 0:
partial = await asyncio.wait_for( try:
self.sock_recv(bytes_left), partial = await self._wait_timeout_or_close(
timeout=self.timeout, self.sock_recv(bytes_left)
loop=self._loop )
) except TimeoutError as e:
if bytes_left < size:
__log__.warning(
'socket timeout "%s" when %d/%d had been received',
e, size - bytes_left, size
)
raise
except self.SocketClosed:
raise ConnectionResetError(
'Socket has closed while reading data'
)
except OSError as e:
if e.errno in CONN_RESET_ERRNOS:
raise ConnectionResetError() from e
else:
raise
if not partial: if not partial:
raise ConnectionResetError() raise ConnectionResetError()
@ -141,7 +210,7 @@ class TcpClient:
def _sock_recv(self, fut, registered_fd, n): def _sock_recv(self, fut, registered_fd, n):
if registered_fd is not None: if registered_fd is not None:
self._loop.remove_reader(registered_fd) self._loop.remove_reader(registered_fd)
if fut.cancelled(): if fut.cancelled() or self._socket is None:
return return
try: try:
@ -165,7 +234,7 @@ class TcpClient:
def _sock_sendall(self, fut, registered_fd, data): def _sock_sendall(self, fut, registered_fd, data):
if registered_fd: if registered_fd:
self._loop.remove_writer(registered_fd) self._loop.remove_writer(registered_fd)
if fut.cancelled(): if fut.cancelled() or self._socket is None:
return return
try: try: