Update code to deal with the new sessions

This commit is contained in:
Lonami Exo
2021-09-19 16:38:11 +02:00
parent 1f5722c925
commit 81b4957d9b
10 changed files with 173 additions and 119 deletions

View File

@@ -6,12 +6,14 @@ import logging
import platform
import time
import typing
import ipaddress
from .. import version, helpers, __name__ as __base_name__, _tl
from .._crypto import rsa
from .._misc import markdown, entitycache, statecache, enums
from .._network import MTProtoSender, Connection, ConnectionTcpFull, connection as conns
from ..sessions import Session, SQLiteSession, MemorySession
from ..sessions.types import DataCenter, SessionState
DEFAULT_DC_ID = 2
DEFAULT_IPV4_IP = '149.154.167.51'
@@ -129,15 +131,6 @@ def init(
'The given session must be a str or a Session instance.'
)
# ':' in session.server_address is True if it's an IPv6 address
if (not session.server_address or
(':' in session.server_address) != use_ipv6):
session.set_dc(
DEFAULT_DC_ID,
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
DEFAULT_PORT
)
self.flood_sleep_threshold = flood_sleep_threshold
# TODO Use AsyncClassWrapper(session)
@@ -230,13 +223,11 @@ def init(
)
self._sender = MTProtoSender(
self.session.auth_key,
loggers=self._log,
retries=self._connection_retries,
delay=self._retry_delay,
auto_reconnect=self._auto_reconnect,
connect_timeout=self._timeout,
auth_key_callback=self._auth_key_callback,
update_callback=self._handle_update,
auto_reconnect_callback=self._handle_auto_reconnect
)
@@ -264,11 +255,6 @@ def init(
self._authorized = None # None = unknown, False = no, True = yes
# Update state (for catching up after a disconnection)
# TODO Get state from channels too
self._state_cache = statecache.StateCache(
self.session.get_update_state(0), self._log)
# Some further state for subclasses
self._event_builders = []
@@ -310,10 +296,33 @@ def set_flood_sleep_threshold(self, value):
async def connect(self: 'TelegramClient') -> None:
all_dc = await self.session.get_all_dc()
state = await self.session.get_state()
dc = None
if state:
for d in all_dc:
if d.id == state.dc_id:
dc = d
break
if dc is None:
dc = DataCenter(
id=DEFAULT_DC_ID,
ipv4=None if self._use_ipv6 else int(ipaddress.ip_address(DEFAULT_IPV4_IP)),
ipv6=int(ipaddress.ip_address(DEFAULT_IPV6_IP)) if self._use_ipv6 else None,
port=DEFAULT_PORT,
auth=b'',
)
# Update state (for catching up after a disconnection)
# TODO Get state from channels too
self._state_cache = statecache.StateCache(state, self._log)
if not await self._sender.connect(self._connection(
self.session.server_address,
self.session.port,
self.session.dc_id,
str(ipaddress.ip_address(dc.ipv6 or dc.ipv4)),
dc.port,
dc.id,
loggers=self._log,
proxy=self._proxy,
local_addr=self._local_addr
@@ -321,8 +330,10 @@ async def connect(self: 'TelegramClient') -> None:
# We don't want to init or modify anything if we were already connected
return
self.session.auth_key = self._sender.auth_key
self.session.save()
if self._sender.auth_key.key != dc.key:
dc.key = self._sender.auth_key.key
await self.session.insert_dc(dc)
await self.session.save()
self._init_request.query = _tl.fn.help.GetConfig()
@@ -388,15 +399,12 @@ async def _disconnect_coro(self: 'TelegramClient'):
pts, date = self._state_cache[None]
if pts and date:
self.session.set_update_state(0, _tl.updates.State(
pts=pts,
qts=0,
date=date,
seq=0,
unread_count=0
))
self.session.close()
state = await self.session.get_state()
if state:
state.pts = pts
state.date = date
await self.session.set_state(state)
await self.session.save()
async def _disconnect(self: 'TelegramClient'):
"""
@@ -414,31 +422,59 @@ async def _switch_dc(self: 'TelegramClient', new_dc):
Permanently switches the current connection to the new data center.
"""
self._log[__name__].info('Reconnecting to new data center %s', new_dc)
dc = await _get_dc(self, new_dc)
dc = await _refresh_and_get_dc(self, new_dc)
state = await self.session.get_state()
if state is None:
state = SessionState(
user_id=0,
dc_id=dc.id,
bot=False,
pts=0,
qts=0,
date=0,
seq=0,
takeout_id=None,
)
else:
state.dc_id = dc.id
await self.session.set_state(dc.id)
await self.session.save()
self.session.set_dc(dc.id, dc.ip_address, dc.port)
# auth_key's are associated with a server, which has now changed
# so it's not valid anymore. Set to None to force recreating it.
self._sender.auth_key.key = None
self.session.auth_key = None
self.session.save()
await _disconnect(self)
return await self.connect()
def _auth_key_callback(self: 'TelegramClient', auth_key):
"""
Callback from the sender whenever it needed to generate a
new authorization key. This means we are not authorized.
"""
self.session.auth_key = auth_key
self.session.save()
async def _refresh_and_get_dc(self: 'TelegramClient', dc_id):
"""
Gets the Data Center (DC) associated to `dc_id`.
async def _get_dc(self: 'TelegramClient', dc_id):
"""Gets the Data Center (DC) associated to 'dc_id'"""
Also take this opportunity to refresh the addresses stored in the session if needed.
"""
cls = self.__class__
if not cls._config:
cls._config = await self(_tl.fn.help.GetConfig())
all_dc = {dc.id: dc for dc in await self.session.get_all_dc()}
for dc in cls._config.dc_options:
if dc.media_only or dc.tcpo_only or dc.cdn:
continue
ip = int(ipaddress.ip_address(dc.ip_address))
if dc.id in all_dc:
all_dc[dc.id].port = dc.port
if dc.ipv6:
all_dc[dc.id].ipv6 = ip
else:
all_dc[dc.id].ipv4 = ip
elif dc.ipv6:
all_dc[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'')
else:
all_dc[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'')
for dc in all_dc.values():
await self.session.insert_dc(dc)
await self.session.save()
try:
return next(
@@ -463,12 +499,12 @@ async def _create_exported_sender(self: 'TelegramClient', dc_id):
"""
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
# for clearly showing how to export the authorization
dc = await _get_dc(self, dc_id)
dc = await _refresh_and_get_dc(self, dc_id)
# Can't reuse self._sender._connection as it has its own seqno.
#
# If one were to do that, Telegram would reset the connection
# with no further clues.
sender = MTProtoSender(None, loggers=self._log)
sender = MTProtoSender(loggers=self._log)
await sender.connect(self._connection(
dc.ip_address,
dc.port,
@@ -503,7 +539,7 @@ async def _borrow_exported_sender(self: 'TelegramClient', dc_id):
self._borrowed_senders[dc_id] = (state, sender)
elif state.need_connect():
dc = await _get_dc(self, dc_id)
dc = await _refresh_and_get_dc(self, dc_id)
await sender.connect(self._connection(
dc.ip_address,
dc.port,