mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-08-11 03:09:31 +00:00
Use async def everywhere
This commit is contained in:
@@ -17,21 +17,21 @@ from ..tl.functions import (
|
||||
)
|
||||
|
||||
|
||||
def do_authentication(connection, retries=5):
|
||||
async def do_authentication(connection, retries=5):
|
||||
if not retries or retries < 0:
|
||||
retries = 1
|
||||
|
||||
last_error = None
|
||||
while retries:
|
||||
try:
|
||||
return _do_authentication(connection)
|
||||
return await _do_authentication(connection)
|
||||
except (SecurityError, AssertionError, NotImplementedError) as e:
|
||||
last_error = e
|
||||
retries -= 1
|
||||
raise last_error
|
||||
|
||||
|
||||
def _do_authentication(connection):
|
||||
async def _do_authentication(connection):
|
||||
"""Executes the authentication process with the Telegram servers.
|
||||
If no error is raised, returns both the authorization key and the
|
||||
time offset.
|
||||
@@ -42,8 +42,8 @@ def _do_authentication(connection):
|
||||
req_pq_request = ReqPqRequest(
|
||||
nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
|
||||
)
|
||||
sender.send(req_pq_request.to_bytes())
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
await sender.send(req_pq_request.to_bytes())
|
||||
with BinaryReader(await sender.receive()) as reader:
|
||||
req_pq_request.on_response(reader)
|
||||
|
||||
res_pq = req_pq_request.result
|
||||
@@ -90,10 +90,10 @@ def _do_authentication(connection):
|
||||
public_key_fingerprint=target_fingerprint,
|
||||
encrypted_data=cipher_text
|
||||
)
|
||||
sender.send(req_dh_params.to_bytes())
|
||||
await sender.send(req_dh_params.to_bytes())
|
||||
|
||||
# Step 2 response: DH Exchange
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
with BinaryReader(await sender.receive()) as reader:
|
||||
req_dh_params.on_response(reader)
|
||||
|
||||
server_dh_params = req_dh_params.result
|
||||
@@ -157,10 +157,10 @@ def _do_authentication(connection):
|
||||
server_nonce=res_pq.server_nonce,
|
||||
encrypted_data=client_dh_encrypted,
|
||||
)
|
||||
sender.send(set_client_dh.to_bytes())
|
||||
await sender.send(set_client_dh.to_bytes())
|
||||
|
||||
# Step 3 response: Complete DH Exchange
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
with BinaryReader(await sender.receive()) as reader:
|
||||
set_client_dh.on_response(reader)
|
||||
|
||||
dh_gen = set_client_dh.result
|
||||
|
@@ -1,14 +1,13 @@
|
||||
import errno
|
||||
import os
|
||||
import struct
|
||||
from datetime import timedelta
|
||||
from zlib import crc32
|
||||
from enum import Enum
|
||||
|
||||
import errno
|
||||
from zlib import crc32
|
||||
|
||||
from ..crypto import AESModeCTR
|
||||
from ..extensions import TcpClient
|
||||
from ..errors import InvalidChecksumError
|
||||
from ..extensions import TcpClient
|
||||
|
||||
|
||||
class ConnectionMode(Enum):
|
||||
@@ -74,9 +73,9 @@ class Connection:
|
||||
setattr(self, 'write', self._write_plain)
|
||||
setattr(self, 'read', self._read_plain)
|
||||
|
||||
def connect(self, ip, port):
|
||||
async def connect(self, ip, port):
|
||||
try:
|
||||
self.conn.connect(ip, port)
|
||||
await self.conn.connect(ip, port)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EISCONN:
|
||||
return # Already connected, no need to re-set everything up
|
||||
@@ -85,16 +84,16 @@ class Connection:
|
||||
|
||||
self._send_counter = 0
|
||||
if self._mode == ConnectionMode.TCP_ABRIDGED:
|
||||
self.conn.write(b'\xef')
|
||||
await self.conn.write(b'\xef')
|
||||
elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
|
||||
self.conn.write(b'\xee\xee\xee\xee')
|
||||
await self.conn.write(b'\xee\xee\xee\xee')
|
||||
elif self._mode == ConnectionMode.TCP_OBFUSCATED:
|
||||
self._setup_obfuscation()
|
||||
await self._setup_obfuscation()
|
||||
|
||||
def get_timeout(self):
|
||||
return self.conn.timeout
|
||||
|
||||
def _setup_obfuscation(self):
|
||||
async def _setup_obfuscation(self):
|
||||
# Obfuscated messages secrets cannot start with any of these
|
||||
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
|
||||
while True:
|
||||
@@ -119,7 +118,7 @@ class Connection:
|
||||
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
|
||||
|
||||
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
|
||||
self.conn.write(bytes(random))
|
||||
await self.conn.write(bytes(random))
|
||||
|
||||
def is_connected(self):
|
||||
return self.conn.connected
|
||||
@@ -135,20 +134,23 @@ class Connection:
|
||||
|
||||
# region Receive message implementations
|
||||
|
||||
def recv(self):
|
||||
async def recv(self):
|
||||
"""Receives and unpacks a message"""
|
||||
# Default implementation is just an error
|
||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||
|
||||
def _recv_tcp_full(self):
|
||||
packet_length_bytes = self.read(4)
|
||||
async def _recv_tcp_full(self):
|
||||
# TODO We don't want another call to this method that could
|
||||
# potentially await on another self.read(n). Is this guaranteed
|
||||
# by asyncio?
|
||||
packet_length_bytes = await self.read(4)
|
||||
packet_length = int.from_bytes(packet_length_bytes, 'little')
|
||||
|
||||
seq_bytes = self.read(4)
|
||||
seq_bytes = await self.read(4)
|
||||
seq = int.from_bytes(seq_bytes, 'little')
|
||||
|
||||
body = self.read(packet_length - 12)
|
||||
checksum = int.from_bytes(self.read(4), 'little')
|
||||
body = await self.read(packet_length - 12)
|
||||
checksum = int.from_bytes(await self.read(4), 'little')
|
||||
|
||||
valid_checksum = crc32(packet_length_bytes + seq_bytes + body)
|
||||
if checksum != valid_checksum:
|
||||
@@ -156,72 +158,70 @@ class Connection:
|
||||
|
||||
return body
|
||||
|
||||
def _recv_intermediate(self):
|
||||
return self.read(int.from_bytes(self.read(4), 'little'))
|
||||
async def _recv_intermediate(self):
|
||||
return await self.read(int.from_bytes(self.read(4), 'little'))
|
||||
|
||||
def _recv_abridged(self):
|
||||
async def _recv_abridged(self):
|
||||
length = int.from_bytes(self.read(1), 'little')
|
||||
if length >= 127:
|
||||
length = int.from_bytes(self.read(3) + b'\0', 'little')
|
||||
|
||||
return self.read(length << 2)
|
||||
return await self.read(length << 2)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Send message implementations
|
||||
|
||||
def send(self, message):
|
||||
async def send(self, message):
|
||||
"""Encapsulates and sends the given message"""
|
||||
# Default implementation is just an error
|
||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||
|
||||
def _send_tcp_full(self, message):
|
||||
async def _send_tcp_full(self, message):
|
||||
# https://core.telegram.org/mtproto#tcp-transport
|
||||
# total length, sequence number, packet and checksum (CRC32)
|
||||
length = len(message) + 12
|
||||
data = struct.pack('<ii', length, self._send_counter) + message
|
||||
crc = struct.pack('<I', crc32(data))
|
||||
self._send_counter += 1
|
||||
self.write(data + crc)
|
||||
await self.write(data + crc)
|
||||
|
||||
def _send_intermediate(self, message):
|
||||
self.write(struct.pack('<i', len(message)) + message)
|
||||
async def _send_intermediate(self, message):
|
||||
await self.write(struct.pack('<i', len(message)) + message)
|
||||
|
||||
def _send_abridged(self, message):
|
||||
async def _send_abridged(self, message):
|
||||
length = len(message) >> 2
|
||||
if length < 127:
|
||||
length = struct.pack('B', length)
|
||||
else:
|
||||
length = b'\x7f' + int.to_bytes(length, 3, 'little')
|
||||
|
||||
self.write(length + message)
|
||||
await self.write(length + message)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Read implementations
|
||||
|
||||
def read(self, length):
|
||||
async def read(self, length):
|
||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||
|
||||
def _read_plain(self, length):
|
||||
return self.conn.read(length)
|
||||
async def _read_plain(self, length):
|
||||
return await self.conn.read(length)
|
||||
|
||||
def _read_obfuscated(self, length):
|
||||
return self._aes_decrypt.encrypt(
|
||||
self.conn.read(length)
|
||||
)
|
||||
async def _read_obfuscated(self, length):
|
||||
return await self._aes_decrypt.encrypt(self.conn.read(length))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Write implementations
|
||||
|
||||
def write(self, data):
|
||||
async def write(self, data):
|
||||
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
|
||||
|
||||
def _write_plain(self, data):
|
||||
self.conn.write(data)
|
||||
async def _write_plain(self, data):
|
||||
await self.conn.write(data)
|
||||
|
||||
def _write_obfuscated(self, data):
|
||||
self.conn.write(self._aes_encrypt.encrypt(data))
|
||||
async def _write_obfuscated(self, data):
|
||||
await self.conn.write(self._aes_encrypt.encrypt(data))
|
||||
|
||||
# endregion
|
||||
|
@@ -16,23 +16,23 @@ class MtProtoPlainSender:
|
||||
self._last_msg_id = 0
|
||||
self._connection = connection
|
||||
|
||||
def connect(self):
|
||||
self._connection.connect()
|
||||
async def connect(self):
|
||||
await self._connection.connect()
|
||||
|
||||
def disconnect(self):
|
||||
self._connection.close()
|
||||
|
||||
def send(self, data):
|
||||
async def send(self, data):
|
||||
"""Sends a plain packet (auth_key_id = 0) containing the
|
||||
given message body (data)
|
||||
"""
|
||||
self._connection.send(
|
||||
await self._connection.send(
|
||||
struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
|
||||
)
|
||||
|
||||
def receive(self):
|
||||
async def receive(self):
|
||||
"""Receives a plain packet, returning the body of the response"""
|
||||
body = self._connection.recv()
|
||||
body = await self._connection.recv()
|
||||
if body == b'l\xfe\xff\xff': # -404 little endian signed
|
||||
# Broken authorization, must reset the auth key
|
||||
raise BrokenAuthKeyError()
|
||||
|
@@ -41,9 +41,9 @@ class MtProtoSender:
|
||||
# Requests (as msg_id: Message) sent waiting to be received
|
||||
self._pending_receive = {}
|
||||
|
||||
def connect(self):
|
||||
async def connect(self):
|
||||
"""Connects to the server"""
|
||||
self.connection.connect(self.session.server_address, self.session.port)
|
||||
await self.connection.connect(self.session.server_address, self.session.port)
|
||||
|
||||
def is_connected(self):
|
||||
return self.connection.is_connected()
|
||||
@@ -60,7 +60,7 @@ class MtProtoSender:
|
||||
|
||||
# region Send and receive
|
||||
|
||||
def send(self, *requests):
|
||||
async def send(self, *requests):
|
||||
"""Sends the specified MTProtoRequest, previously sending any message
|
||||
which needed confirmation."""
|
||||
|
||||
@@ -80,13 +80,13 @@ class MtProtoSender:
|
||||
else:
|
||||
message = TLMessage(self.session, MessageContainer(messages))
|
||||
|
||||
self._send_message(message)
|
||||
await self._send_message(message)
|
||||
|
||||
def _send_acknowledge(self, msg_id):
|
||||
async def _send_acknowledge(self, msg_id):
|
||||
"""Sends a message acknowledge for the given msg_id"""
|
||||
self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
|
||||
await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
|
||||
|
||||
def receive(self, update_state):
|
||||
async def receive(self, update_state):
|
||||
"""Receives a single message from the connected endpoint.
|
||||
|
||||
This method returns nothing, and will only affect other parts
|
||||
@@ -97,7 +97,7 @@ class MtProtoSender:
|
||||
update_state.process(TLObject).
|
||||
"""
|
||||
try:
|
||||
body = self.connection.recv()
|
||||
body = await self.connection.recv()
|
||||
except (BufferError, InvalidChecksumError):
|
||||
# TODO BufferError, we should spot the cause...
|
||||
# "No more bytes left"; something wrong happened, clear
|
||||
@@ -111,13 +111,13 @@ class MtProtoSender:
|
||||
|
||||
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
||||
with BinaryReader(message) as reader:
|
||||
self._process_msg(remote_msg_id, remote_seq, reader, update_state)
|
||||
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Low level processing
|
||||
|
||||
def _send_message(self, message):
|
||||
async def _send_message(self, message):
|
||||
"""Sends the given Message(TLObject) encrypted through the network"""
|
||||
|
||||
plain_text = \
|
||||
@@ -130,7 +130,7 @@ class MtProtoSender:
|
||||
cipher_text = AES.encrypt_ige(plain_text, key, iv)
|
||||
|
||||
result = key_id + msg_key + cipher_text
|
||||
self.connection.send(result)
|
||||
await self.connection.send(result)
|
||||
|
||||
def _decode_msg(self, body):
|
||||
"""Decodes an received encrypted message body bytes"""
|
||||
@@ -163,7 +163,7 @@ class MtProtoSender:
|
||||
|
||||
return message, remote_msg_id, remote_sequence
|
||||
|
||||
def _process_msg(self, msg_id, sequence, reader, state):
|
||||
async def _process_msg(self, msg_id, sequence, reader, state):
|
||||
"""Processes and handles a Telegram message.
|
||||
|
||||
Returns True if the message was handled correctly and doesn't
|
||||
@@ -178,22 +178,22 @@ class MtProtoSender:
|
||||
|
||||
# The following codes are "parsed manually"
|
||||
if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
|
||||
return self._handle_rpc_result(msg_id, sequence, reader)
|
||||
return await self._handle_rpc_result(msg_id, sequence, reader)
|
||||
|
||||
if code == 0x347773c5: # pong
|
||||
return self._handle_pong(msg_id, sequence, reader)
|
||||
return await self._handle_pong(msg_id, sequence, reader)
|
||||
|
||||
if code == 0x73f1f8dc: # msg_container
|
||||
return self._handle_container(msg_id, sequence, reader, state)
|
||||
return await self._handle_container(msg_id, sequence, reader, state)
|
||||
|
||||
if code == 0x3072cfa1: # gzip_packed
|
||||
return self._handle_gzip_packed(msg_id, sequence, reader, state)
|
||||
return await self._handle_gzip_packed(msg_id, sequence, reader, state)
|
||||
|
||||
if code == 0xedab447b: # bad_server_salt
|
||||
return self._handle_bad_server_salt(msg_id, sequence, reader)
|
||||
return await self._handle_bad_server_salt(msg_id, sequence, reader)
|
||||
|
||||
if code == 0xa7eff811: # bad_msg_notification
|
||||
return self._handle_bad_msg_notification(msg_id, sequence, reader)
|
||||
return await self._handle_bad_msg_notification(msg_id, sequence, reader)
|
||||
|
||||
# msgs_ack, it may handle the request we wanted
|
||||
if code == 0x62d6b459:
|
||||
@@ -247,7 +247,7 @@ class MtProtoSender:
|
||||
r.confirm_received.set()
|
||||
self._pending_receive.clear()
|
||||
|
||||
def _handle_pong(self, msg_id, sequence, reader):
|
||||
async def _handle_pong(self, msg_id, sequence, reader):
|
||||
self._logger.debug('Handling pong')
|
||||
reader.read_int(signed=False) # code
|
||||
received_msg_id = reader.read_long()
|
||||
@@ -259,7 +259,7 @@ class MtProtoSender:
|
||||
|
||||
return True
|
||||
|
||||
def _handle_container(self, msg_id, sequence, reader, state):
|
||||
async def _handle_container(self, msg_id, sequence, reader, state):
|
||||
self._logger.debug('Handling container')
|
||||
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
|
||||
begin_position = reader.tell_position()
|
||||
@@ -267,7 +267,7 @@ class MtProtoSender:
|
||||
# Note that this code is IMPORTANT for skipping RPC results of
|
||||
# lost requests (i.e., ones from the previous connection session)
|
||||
try:
|
||||
if not self._process_msg(inner_msg_id, sequence, reader, state):
|
||||
if not await self._process_msg(inner_msg_id, sequence, reader, state):
|
||||
reader.set_position(begin_position + inner_len)
|
||||
except:
|
||||
# If any error is raised, something went wrong; skip the packet
|
||||
@@ -276,7 +276,7 @@ class MtProtoSender:
|
||||
|
||||
return True
|
||||
|
||||
def _handle_bad_server_salt(self, msg_id, sequence, reader):
|
||||
async def _handle_bad_server_salt(self, msg_id, sequence, reader):
|
||||
self._logger.debug('Handling bad server salt')
|
||||
reader.read_int(signed=False) # code
|
||||
bad_msg_id = reader.read_long()
|
||||
@@ -287,11 +287,11 @@ class MtProtoSender:
|
||||
|
||||
request = self._pop_request(bad_msg_id)
|
||||
if request:
|
||||
self.send(request)
|
||||
await self.send(request)
|
||||
|
||||
return True
|
||||
|
||||
def _handle_bad_msg_notification(self, msg_id, sequence, reader):
|
||||
async def _handle_bad_msg_notification(self, msg_id, sequence, reader):
|
||||
self._logger.debug('Handling bad message notification')
|
||||
reader.read_int(signed=False) # code
|
||||
reader.read_long() # request_id
|
||||
@@ -318,7 +318,7 @@ class MtProtoSender:
|
||||
else:
|
||||
raise error
|
||||
|
||||
def _handle_rpc_result(self, msg_id, sequence, reader):
|
||||
async def _handle_rpc_result(self, msg_id, sequence, reader):
|
||||
self._logger.debug('Handling RPC result')
|
||||
reader.read_int(signed=False) # code
|
||||
request_id = reader.read_long()
|
||||
@@ -338,7 +338,7 @@ class MtProtoSender:
|
||||
)
|
||||
|
||||
# Acknowledge that we received the error
|
||||
self._send_acknowledge(request_id)
|
||||
await self._send_acknowledge(request_id)
|
||||
|
||||
if request:
|
||||
request.rpc_error = error
|
||||
@@ -366,9 +366,9 @@ class MtProtoSender:
|
||||
self._logger.debug('Lost request will be skipped.')
|
||||
return False
|
||||
|
||||
def _handle_gzip_packed(self, msg_id, sequence, reader, state):
|
||||
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
|
||||
self._logger.debug('Handling gzip packed data')
|
||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||
return self._process_msg(msg_id, sequence, compressed_reader, state)
|
||||
return await self._process_msg(msg_id, sequence, compressed_reader, state)
|
||||
|
||||
# endregion
|
||||
|
Reference in New Issue
Block a user