Reduce __enter__/__exit__ boilerplate for sync ctx managers

This commit is contained in:
Lonami Exo
2019-04-13 10:53:33 +02:00
parent badefcec48
commit 9090ede5db
5 changed files with 33 additions and 51 deletions

View File

@@ -2,7 +2,7 @@ import functools
import inspect
from .users import UserMethods, _NOT_A_REQUEST
from .. import utils
from .. import helpers, utils
from ..tl import functions, TLRequest
@@ -31,15 +31,6 @@ class _TakeoutClient:
def success(self, value):
self.__success = value
def __enter__(self):
if self.__client.loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self.__client.loop.run_until_complete(self.__aenter__())
async def __aenter__(self):
# Enter/Exit behaviour is "overrode", we don't want to call start.
client = self.__client
@@ -50,9 +41,6 @@ class _TakeoutClient:
"takeout for the current session still not been finished yet.")
return self
def __exit__(self, *args):
return self.__client.loop.run_until_complete(self.__aexit__(*args))
async def __aexit__(self, exc_type, exc_value, traceback):
if self.__success is None and self.__finalize:
self.__success = exc_type is None
@@ -64,6 +52,9 @@ class _TakeoutClient:
raise ValueError("Failed to finish the takeout.")
self.session.takeout_id = None
__enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit
async def __call__(self, request, ordered=False):
takeout_id = self.__client.session.takeout_id
if takeout_id is None:

View File

@@ -536,23 +536,13 @@ class AuthMethods(MessageParseMethods, UserMethods):
# region with blocks
def __enter__(self):
if self._loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self.start()
async def __aenter__(self):
return await self.start()
def __exit__(self, *args):
# No loop.run_until_complete; it's already syncified
self.disconnect()
async def __aexit__(self, *args):
await self.disconnect()
__enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit
# endregion

View File

@@ -3,7 +3,7 @@ import itertools
import string
from .users import UserMethods
from .. import utils
from .. import helpers, utils
from ..requestiter import RequestIter
from ..tl import types, functions, custom
@@ -69,17 +69,8 @@ class _ChatAction:
self._task = None
def __enter__(self):
if self._client.loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self._client.loop.run_until_complete(self.__aenter__())
def __exit__(self, *args):
return self._client.loop.run_until_complete(self.__aexit__(*args))
__enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit
async def _update(self):
try: