Handle all entity types on isinstance checks

Only the uses of `isinstance` against `InputPeer*` types were
reviewed. Notably, `utils` is exempt on this because it needs
to deal with everything on a case-by-case basis.

Since the addition of `*FromMessage` peers, any manual `isinstance`
checks to determine the type were prone to breaking or being
forgotten to be updated, so a common `helpers._entity_type()`
method was made to share this logic.

Since the conversion to `Peer` would be too expensive, a simpler
check against the name is made, which should be fast and cheap.
This commit is contained in:
Lonami Exo
2019-12-23 13:52:07 +01:00
parent 627e176f8e
commit fa736f81af
9 changed files with 97 additions and 37 deletions

View File

@@ -108,8 +108,8 @@ class _ParticipantsIter(RequestIter):
filter = filter()
entity = await self.client.get_input_entity(entity)
if search and (filter
or not isinstance(entity, types.InputPeerChannel)):
ty = helpers._entity_type(entity)
if search and (filter or ty != helpers._EntityType.CHANNEL):
# We need to 'search' ourselves unless we have a PeerChannel
search = search.casefold()
@@ -123,7 +123,7 @@ class _ParticipantsIter(RequestIter):
# Only used for channels, but we should always set the attribute
self.requests = []
if isinstance(entity, types.InputPeerChannel):
if ty == helpers._EntityType.CHANNEL:
self.total = (await self.client(
functions.channels.GetFullChannelRequest(entity)
)).full_chat.participants_count
@@ -149,7 +149,7 @@ class _ParticipantsIter(RequestIter):
hash=0
))
elif isinstance(entity, types.InputPeerChat):
elif ty == helpers._EntityType.CHAT:
full = await self.client(
functions.messages.GetFullChatRequest(entity.chat_id))
if not isinstance(
@@ -281,7 +281,8 @@ class _ProfilePhotoIter(RequestIter):
self, entity, offset, max_id
):
entity = await self.client.get_input_entity(entity)
if isinstance(entity, (types.InputPeerUser, types.InputPeerSelf)):
ty = helpers._entity_type(entity)
if ty == helpers._EntityType.USER:
self.request = functions.photos.GetUserPhotosRequest(
entity,
offset=offset,
@@ -864,7 +865,8 @@ class ChatMethods:
"""
entity = await self.get_input_entity(entity)
user = await self.get_input_entity(user)
if not isinstance(user, (types.InputPeerUser, types.InputPeerSelf)):
ty = helpers._entity_type(user)
if ty != helpers._EntityType.USER:
raise ValueError('You must pass a user entity')
perm_names = (
@@ -872,7 +874,8 @@ class ChatMethods:
'ban_users', 'invite_users', 'pin_messages', 'add_admins'
)
if isinstance(entity, types.InputPeerChannel):
ty = helpers._entity_type(entity)
if ty == helpers._EntityType.CHANNEL:
# If we try to set these permissions in a megagroup, we
# would get a RIGHT_FORBIDDEN. However, it makes sense
# that an admin can post messages, so we want to avoid the error
@@ -894,7 +897,7 @@ class ChatMethods:
for name in perm_names
}), rank=title or ''))
elif isinstance(entity, types.InputPeerChat):
elif ty == helpers._EntityType.CHAT:
# If the user passed any permission in a small
# group chat, they must be a full admin to have it.
if is_admin is None:
@@ -1015,7 +1018,8 @@ class ChatMethods:
await client.edit_permissions(chat, user)
"""
entity = await self.get_input_entity(entity)
if not isinstance(entity, types.InputPeerChannel):
ty = helpers._entity_type(entity)
if ty != helpers._EntityType.CHANNEL:
raise ValueError('You must pass either a channel or a supergroup')
rights = types.ChatBannedRights(
@@ -1040,12 +1044,13 @@ class ChatMethods:
))
user = await self.get_input_entity(user)
ty = helpers._entity_type(user)
if ty != helpers._EntityType.USER:
raise ValueError('You must pass a user entity')
if isinstance(user, types.InputPeerSelf):
raise ValueError('You cannot restrict yourself')
if not isinstance(user, types.InputPeerUser):
raise ValueError('You must pass a user entity')
return await self(functions.channels.EditBannedRequest(
channel=entity,
user_id=user,
@@ -1086,12 +1091,13 @@ class ChatMethods:
"""
entity = await self.get_input_entity(entity)
user = await self.get_input_entity(user)
if not isinstance(user, (types.InputPeerUser, types.InputPeerSelf)):
if helpers._entity_type(user) != helpers._EntityType.USER:
raise ValueError('You must pass a user entity')
if isinstance(entity, types.InputPeerChat):
ty = helpers._entity_type(entity)
if ty == helpers._EntityType.CHAT:
await self(functions.messages.DeleteChatUserRequest(entity.chat_id, user))
elif isinstance(entity, types.InputPeerChannel):
elif ty == helpers._EntityType.CHANNEL:
if isinstance(user, types.InputPeerSelf):
await self(functions.channels.LeaveChannelRequest(entity))
else:

View File

@@ -3,7 +3,7 @@ import inspect
import itertools
import typing
from .. import utils, hints
from .. import helpers, utils, hints
from ..requestiter import RequestIter
from ..tl import types, functions, custom
@@ -436,10 +436,11 @@ class DialogMethods:
await client.delete_dialog('username')
"""
entity = await self.get_input_entity(entity)
if isinstance(entity, types.InputPeerChannel):
ty = helpers._entity_type(entity)
if ty == helpers._EntityType.CHANNEL:
return await self(functions.channels.LeaveChannelRequest(entity))
if isinstance(entity, types.InputPeerChat):
if ty == helpers._EntityType.CHAT:
result = await self(functions.messages.DeleteChatUserRequest(
entity.chat_id, types.InputUserSelf()))
else:

View File

@@ -257,7 +257,8 @@ class DownloadMethods:
# See issue #500, Android app fails as of v4.6.0 (1155).
# The fix seems to be using the full channel chat photo.
ie = await self.get_input_entity(entity)
if isinstance(ie, types.InputPeerChannel):
ty = helpers._entity_type(ie)
if ty == helpers._EntityType.CHANNEL:
full = await self(functions.channels.GetFullChannelRequest(ie))
return await self._download_photo(
full.full_chat.chat_photo, file,

View File

@@ -2,7 +2,7 @@ import itertools
import re
import typing
from .. import utils
from .. import helpers, utils
from ..tl import types
if typing.TYPE_CHECKING:
@@ -134,7 +134,7 @@ class MessageParseMethods:
id_to_message[update.message.id] = update.message
elif (isinstance(update, types.UpdateEditMessage)
and not isinstance(request.peer, types.InputPeerChannel)):
and helpers._entity_type(request.peer) != helpers._EntityType.CHANNEL):
if request.id == update.message.id:
update.message._finish_init(self, entities, input_chat)
return update.message

View File

@@ -2,7 +2,7 @@ import inspect
import itertools
import typing
from .. import utils, errors, hints
from .. import helpers, utils, errors, hints
from ..requestiter import RequestIter
from ..tl import types, functions
@@ -57,8 +57,8 @@ class _MessagesIter(RequestIter):
if from_user:
from_user = await self.client.get_input_entity(from_user)
if not isinstance(from_user, (
types.InputPeerUser, types.InputPeerSelf)):
ty = helpers._entity_type(from_user)
if ty != helpers._EntityType.USER:
from_user = None # Ignore from_user unless it's a user
if from_user:
@@ -86,8 +86,8 @@ class _MessagesIter(RequestIter):
filter = types.InputMessagesFilterEmpty()
# Telegram completely ignores `from_id` in private chats
if isinstance(
self.entity, (types.InputPeerUser, types.InputPeerSelf)):
ty = helpers._entity_type(self.entity)
if ty == helpers._EntityType.USER:
# Don't bother sending `from_user` (it's ignored anyway),
# but keep `from_id` defined above to check it locally.
from_user = None
@@ -246,6 +246,7 @@ class _IDsIter(RequestIter):
self._ids = list(reversed(ids)) if self.reverse else ids
self._offset = 0
self._entity = (await self.client.get_input_entity(entity)) if entity else None
self._ty = helpers._EntityType(self._entity) if self._entity else None
# 30s flood wait every 300 messages (3 requests of 100 each, 30 of 10, etc.)
if self.wait_time is None:
@@ -259,7 +260,7 @@ class _IDsIter(RequestIter):
self._offset += _MAX_CHUNK_SIZE
from_id = None # By default, no need to validate from_id
if isinstance(self._entity, (types.InputChannel, types.InputPeerChannel)):
if self._ty == helpers._EntityType.CHANNEL:
try:
r = await self.client(
functions.channels.GetMessagesRequest(self._entity, ids))
@@ -1108,7 +1109,7 @@ class MessageMethods:
)
entity = await self.get_input_entity(entity) if entity else None
if isinstance(entity, types.InputPeerChannel):
if helpers._entity_type(entity) == helpers._EntityType.CHANNEL:
return await self([functions.channels.DeleteMessagesRequest(
entity, list(c)) for c in utils.chunks(message_ids)])
else:
@@ -1181,7 +1182,7 @@ class MessageMethods:
return True
if max_id is not None:
if isinstance(entity, types.InputPeerChannel):
if helpers._entity_type(entity) == helpers._EntityType.CHANNEL:
return await self(functions.channels.ReadHistoryRequest(
utils.get_input_channel(entity), max_id=max_id))
else:

View File

@@ -4,7 +4,7 @@ import itertools
import time
import typing
from .. import errors, utils, hints
from .. import errors, helpers, utils, hints
from ..errors import MultiError, RPCError
from ..helpers import retry_range
from ..tl import TLRequest, types, functions
@@ -258,12 +258,20 @@ class UserMethods:
else:
inputs.append(await self.get_input_entity(x))
users = [x for x in inputs
if isinstance(x, (types.InputPeerUser, types.InputPeerSelf))]
chats = [x.chat_id for x in inputs
if isinstance(x, types.InputPeerChat)]
channels = [x for x in inputs
if isinstance(x, types.InputPeerChannel)]
lists = {
helpers._EntityType.USER: [],
helpers._EntityType.CHAT: [],
helpers._EntityType.CHANNEL: [],
}
for x in inputs:
try:
lists[helpers._entity_type(x)].append(x)
except TypeError:
pass
users = lists[helpers._EntityType.USER]
chats = lists[helpers._EntityType.CHAT]
channels = lists[helpers._EntityType.CHANNEL]
if users:
# GetUsersRequest has a limit of 200 per call
tmp = []