Remove the aggressive hack from get_participants

This commit is contained in:
Lonami Exo
2021-09-17 20:13:05 +02:00
parent be3ed894c6
commit 1036c3cb52
3 changed files with 46 additions and 76 deletions

View File

@@ -94,7 +94,7 @@ class _ChatAction:
class _ParticipantsIter(requestiter.RequestIter):
async def _init(self, entity, filter, search, aggressive):
async def _init(self, entity, filter, search):
if isinstance(filter, type):
if filter in (_tl.ChannelParticipantsBanned,
_tl.ChannelParticipantsKicked,
@@ -118,9 +118,6 @@ class _ParticipantsIter(requestiter.RequestIter):
else:
self.filter_entity = lambda ent: True
# Only used for channels, but we should always set the attribute
self.requests = []
if ty == helpers._EntityType.CHANNEL:
if self.limit <= 0:
# May not have access to the channel, but getFull can get the .total.
@@ -130,22 +127,13 @@ class _ParticipantsIter(requestiter.RequestIter):
raise StopAsyncIteration
self.seen = set()
if aggressive and not filter:
self.requests.extend(_tl.fn.channels.GetParticipants(
channel=entity,
filter=_tl.ChannelParticipantsSearch(x),
offset=0,
limit=_MAX_PARTICIPANTS_CHUNK_SIZE,
hash=0
) for x in (search or string.ascii_lowercase))
else:
self.requests.append(_tl.fn.channels.GetParticipants(
channel=entity,
filter=filter or _tl.ChannelParticipantsSearch(search),
offset=0,
limit=_MAX_PARTICIPANTS_CHUNK_SIZE,
hash=0
))
self.request = _tl.fn.channels.GetParticipants(
channel=entity,
filter=filter or _tl.ChannelParticipantsSearch(search),
offset=0,
limit=_MAX_PARTICIPANTS_CHUNK_SIZE,
hash=0
)
elif ty == helpers._EntityType.CHAT:
full = await self.client(
@@ -184,24 +172,21 @@ class _ParticipantsIter(requestiter.RequestIter):
return True
async def _load_next_chunk(self):
if not self.requests:
return True
# Only care about the limit for the first request
# (small amount of people, won't be aggressive).
# (small amount of people).
#
# Most people won't care about getting exactly 12,345
# members so it doesn't really matter not to be 100%
# precise with being out of the offset/limit here.
self.requests[0].limit = min(
self.limit - self.requests[0].offset, _MAX_PARTICIPANTS_CHUNK_SIZE)
self.request.limit = min(
self.limit - self.request.offset, _MAX_PARTICIPANTS_CHUNK_SIZE)
if self.requests[0].offset > self.limit:
if self.request.offset > self.limit:
return True
if self.total is None:
f = self.requests[0].filter
if len(self.requests) > 1 or (
f = self.request.filter
if (
not isinstance(f, _tl.ChannelParticipantsRecent)
and (not isinstance(f, _tl.ChannelParticipantsSearch) or f.q)
):
@@ -209,42 +194,36 @@ class _ParticipantsIter(requestiter.RequestIter):
# if there's a filter which would reduce the real total number.
# getParticipants is cheaper than getFull.
self.total = (await self.client(_tl.fn.channels.GetParticipants(
channel=self.requests[0].channel,
channel=self.request.channel,
filter=_tl.ChannelParticipantsRecent(),
offset=0,
limit=1,
hash=0
))).count
results = await self.client(self.requests)
for i in reversed(range(len(self.requests))):
participants = results[i]
if self.total is None:
# Will only get here if there was one request with a filter that matched all users.
self.total = participants.count
if not participants.users:
self.requests.pop(i)
continue
participants = await self.client(self.request)
if self.total is None:
# Will only get here if there was one request with a filter that matched all users.
self.total = participants.count
self.requests[i].offset += len(participants.participants)
users = {user.id: user for user in participants.users}
for participant in participants.participants:
if isinstance(participant, _tl.ChannelParticipantBanned):
if not isinstance(participant.peer, _tl.PeerUser):
# May have the entire channel banned. See #3105.
continue
user_id = participant.peer.user_id
else:
user_id = participant.user_id
user = users[user_id]
if not self.filter_entity(user) or user.id in self.seen:
self.request.offset += len(participants.participants)
users = {user.id: user for user in participants.users}
for participant in participants.participants:
if isinstance(participant, _tl.ChannelParticipantBanned):
if not isinstance(participant.peer, _tl.PeerUser):
# May have the entire channel banned. See #3105.
continue
self.seen.add(user_id)
user = users[user_id]
user.participant = participant
self.buffer.append(user)
user_id = participant.peer.user_id
else:
user_id = participant.user_id
user = users[user_id]
if not self.filter_entity(user) or user.id in self.seen:
continue
self.seen.add(user_id)
user = users[user_id]
user.participant = participant
self.buffer.append(user)
class _AdminLogIter(requestiter.RequestIter):
@@ -407,15 +386,13 @@ def get_participants(
limit: float = None,
*,
search: str = '',
filter: '_tl.TypeChannelParticipantsFilter' = None,
aggressive: bool = False) -> _ParticipantsIter:
filter: '_tl.TypeChannelParticipantsFilter' = None) -> _ParticipantsIter:
return _ParticipantsIter(
self,
limit,
entity=entity,
filter=filter,
search=search,
aggressive=aggressive
search=search
)