Update code to deal with the new sessions

This commit is contained in:
Lonami Exo
2021-09-19 16:38:11 +02:00
parent 1f5722c925
commit 81b4957d9b
10 changed files with 173 additions and 119 deletions

View File

@@ -3,6 +3,7 @@ import itertools
from .._misc import utils
from .. import _tl
from ..sessions.types import Entity
# Which updates have the following fields?
_has_field = {
@@ -51,27 +52,60 @@ class EntityCache:
"""
In-memory input entity cache, defaultdict-like behaviour.
"""
def add(self, entities):
def add(self, entities, _mappings={
_tl.User.CONSTRUCTOR_ID: lambda e: (Entity.BOT if e.bot else Entity.USER, e.id, e.access_hash),
_tl.UserFull.CONSTRUCTOR_ID: lambda e: (Entity.BOT if e.user.bot else Entity.USER, e.user.id, e.user.access_hash),
_tl.Chat.CONSTRUCTOR_ID: lambda e: (Entity.GROUP, e.id, 0),
_tl.ChatFull.CONSTRUCTOR_ID: lambda e: (Entity.GROUP, e.id, 0),
_tl.ChatEmpty.CONSTRUCTOR_ID: lambda e: (Entity.GROUP, e.id, 0),
_tl.ChatForbidden.CONSTRUCTOR_ID: lambda e: (Entity.GROUP, e.id, 0),
_tl.Channel.CONSTRUCTOR_ID: lambda e: (
Entity.MEGAGROUP if e.megagroup else (Entity.GIGAGROUP if e.gigagroup else Entity.CHANNEL),
e.id,
e.access_hash,
),
_tl.ChannelForbidden.CONSTRUCTOR_ID: lambda e: (Entity.MEGAGROUP if e.megagroup else Entity.CHANNEL, e.id, e.access_hash),
}):
"""
Adds the given entities to the cache, if they weren't saved before.
Returns a list of Entity that can be saved in the session.
"""
if not utils.is_list_like(entities):
# Invariant: all "chats" and "users" are always iterables,
# and "user" never is (so we wrap it inside a list).
# and "user" and "chat" never are (so we wrap them inside a list).
#
# Itself may be already the entity we want to cache.
entities = itertools.chain(
[entities],
getattr(entities, 'chats', []),
getattr(entities, 'users', []),
(hasattr(entities, 'user') and [entities.user]) or []
(hasattr(entities, 'user') and [entities.user]) or [],
(hasattr(entities, 'chat') and [entities.user]) or [],
)
for entity in entities:
rows = []
for e in entities:
try:
pid = utils.get_peer_id(entity)
if pid not in self.__dict__:
# Note: `get_input_peer` already checks for `access_hash`
self.__dict__[pid] = utils.get_input_peer(entity)
except TypeError:
pass
mapper = _mappings[e.CONSTRUCTOR_ID]
except (AttributeError, KeyError):
continue
ty, id, access_hash = mapper(e)
# Need to check for non-zero access hash unless it's a group (#354 and #392).
# Also check it's not `min` (`access_hash` usage is limited since layer 102).
if not getattr(e, 'min', False) and (access_hash or ty == Entity.GROUP):
rows.append(Entity(ty, id, access_hash))
if id not in self.__dict__:
if ty in (Entity.USER, Entity.BOT):
self.__dict__[id] = _tl.InputPeerUser(id, access_hash)
elif ty in (Entity.GROUP):
self.__dict__[id] = _tl.InputPeerChat(id)
elif ty in (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP):
self.__dict__[id] = _tl.InputPeerChannel(id, access_hash)
return rows
def __getitem__(self, item):
"""