Use async def everywhere

This commit is contained in:
Lonami Exo
2017-10-06 21:02:41 +02:00
parent 9716d1d543
commit 77c99db066
7 changed files with 206 additions and 208 deletions

View File

@@ -17,21 +17,21 @@ from ..tl.functions import (
)
def do_authentication(connection, retries=5):
async def do_authentication(connection, retries=5):
if not retries or retries < 0:
retries = 1
last_error = None
while retries:
try:
return _do_authentication(connection)
return await _do_authentication(connection)
except (SecurityError, AssertionError, NotImplementedError) as e:
last_error = e
retries -= 1
raise last_error
def _do_authentication(connection):
async def _do_authentication(connection):
"""Executes the authentication process with the Telegram servers.
If no error is raised, returns both the authorization key and the
time offset.
@@ -42,8 +42,8 @@ def _do_authentication(connection):
req_pq_request = ReqPqRequest(
nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
)
sender.send(req_pq_request.to_bytes())
with BinaryReader(sender.receive()) as reader:
await sender.send(req_pq_request.to_bytes())
with BinaryReader(await sender.receive()) as reader:
req_pq_request.on_response(reader)
res_pq = req_pq_request.result
@@ -90,10 +90,10 @@ def _do_authentication(connection):
public_key_fingerprint=target_fingerprint,
encrypted_data=cipher_text
)
sender.send(req_dh_params.to_bytes())
await sender.send(req_dh_params.to_bytes())
# Step 2 response: DH Exchange
with BinaryReader(sender.receive()) as reader:
with BinaryReader(await sender.receive()) as reader:
req_dh_params.on_response(reader)
server_dh_params = req_dh_params.result
@@ -157,10 +157,10 @@ def _do_authentication(connection):
server_nonce=res_pq.server_nonce,
encrypted_data=client_dh_encrypted,
)
sender.send(set_client_dh.to_bytes())
await sender.send(set_client_dh.to_bytes())
# Step 3 response: Complete DH Exchange
with BinaryReader(sender.receive()) as reader:
with BinaryReader(await sender.receive()) as reader:
set_client_dh.on_response(reader)
dh_gen = set_client_dh.result

View File

@@ -1,14 +1,13 @@
import errno
import os
import struct
from datetime import timedelta
from zlib import crc32
from enum import Enum
import errno
from zlib import crc32
from ..crypto import AESModeCTR
from ..extensions import TcpClient
from ..errors import InvalidChecksumError
from ..extensions import TcpClient
class ConnectionMode(Enum):
@@ -74,9 +73,9 @@ class Connection:
setattr(self, 'write', self._write_plain)
setattr(self, 'read', self._read_plain)
def connect(self, ip, port):
async def connect(self, ip, port):
try:
self.conn.connect(ip, port)
await self.conn.connect(ip, port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
@@ -85,16 +84,16 @@ class Connection:
self._send_counter = 0
if self._mode == ConnectionMode.TCP_ABRIDGED:
self.conn.write(b'\xef')
await self.conn.write(b'\xef')
elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
self.conn.write(b'\xee\xee\xee\xee')
await self.conn.write(b'\xee\xee\xee\xee')
elif self._mode == ConnectionMode.TCP_OBFUSCATED:
self._setup_obfuscation()
await self._setup_obfuscation()
def get_timeout(self):
return self.conn.timeout
def _setup_obfuscation(self):
async def _setup_obfuscation(self):
# Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
while True:
@@ -119,7 +118,7 @@ class Connection:
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
self.conn.write(bytes(random))
await self.conn.write(bytes(random))
def is_connected(self):
return self.conn.connected
@@ -135,20 +134,23 @@ class Connection:
# region Receive message implementations
def recv(self):
async def recv(self):
"""Receives and unpacks a message"""
# Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _recv_tcp_full(self):
packet_length_bytes = self.read(4)
async def _recv_tcp_full(self):
# TODO We don't want another call to this method that could
# potentially await on another self.read(n). Is this guaranteed
# by asyncio?
packet_length_bytes = await self.read(4)
packet_length = int.from_bytes(packet_length_bytes, 'little')
seq_bytes = self.read(4)
seq_bytes = await self.read(4)
seq = int.from_bytes(seq_bytes, 'little')
body = self.read(packet_length - 12)
checksum = int.from_bytes(self.read(4), 'little')
body = await self.read(packet_length - 12)
checksum = int.from_bytes(await self.read(4), 'little')
valid_checksum = crc32(packet_length_bytes + seq_bytes + body)
if checksum != valid_checksum:
@@ -156,72 +158,70 @@ class Connection:
return body
def _recv_intermediate(self):
return self.read(int.from_bytes(self.read(4), 'little'))
async def _recv_intermediate(self):
return await self.read(int.from_bytes(self.read(4), 'little'))
def _recv_abridged(self):
async def _recv_abridged(self):
length = int.from_bytes(self.read(1), 'little')
if length >= 127:
length = int.from_bytes(self.read(3) + b'\0', 'little')
return self.read(length << 2)
return await self.read(length << 2)
# endregion
# region Send message implementations
def send(self, message):
async def send(self, message):
"""Encapsulates and sends the given message"""
# Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _send_tcp_full(self, message):
async def _send_tcp_full(self, message):
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(message) + 12
data = struct.pack('<ii', length, self._send_counter) + message
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
self.write(data + crc)
await self.write(data + crc)
def _send_intermediate(self, message):
self.write(struct.pack('<i', len(message)) + message)
async def _send_intermediate(self, message):
await self.write(struct.pack('<i', len(message)) + message)
def _send_abridged(self, message):
async def _send_abridged(self, message):
length = len(message) >> 2
if length < 127:
length = struct.pack('B', length)
else:
length = b'\x7f' + int.to_bytes(length, 3, 'little')
self.write(length + message)
await self.write(length + message)
# endregion
# region Read implementations
def read(self, length):
async def read(self, length):
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _read_plain(self, length):
return self.conn.read(length)
async def _read_plain(self, length):
return await self.conn.read(length)
def _read_obfuscated(self, length):
return self._aes_decrypt.encrypt(
self.conn.read(length)
)
async def _read_obfuscated(self, length):
return await self._aes_decrypt.encrypt(self.conn.read(length))
# endregion
# region Write implementations
def write(self, data):
async def write(self, data):
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _write_plain(self, data):
self.conn.write(data)
async def _write_plain(self, data):
await self.conn.write(data)
def _write_obfuscated(self, data):
self.conn.write(self._aes_encrypt.encrypt(data))
async def _write_obfuscated(self, data):
await self.conn.write(self._aes_encrypt.encrypt(data))
# endregion

View File

@@ -16,23 +16,23 @@ class MtProtoPlainSender:
self._last_msg_id = 0
self._connection = connection
def connect(self):
self._connection.connect()
async def connect(self):
await self._connection.connect()
def disconnect(self):
self._connection.close()
def send(self, data):
async def send(self, data):
"""Sends a plain packet (auth_key_id = 0) containing the
given message body (data)
"""
self._connection.send(
await self._connection.send(
struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
)
def receive(self):
async def receive(self):
"""Receives a plain packet, returning the body of the response"""
body = self._connection.recv()
body = await self._connection.recv()
if body == b'l\xfe\xff\xff': # -404 little endian signed
# Broken authorization, must reset the auth key
raise BrokenAuthKeyError()

View File

@@ -41,9 +41,9 @@ class MtProtoSender:
# Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {}
def connect(self):
async def connect(self):
"""Connects to the server"""
self.connection.connect(self.session.server_address, self.session.port)
await self.connection.connect(self.session.server_address, self.session.port)
def is_connected(self):
return self.connection.is_connected()
@@ -60,7 +60,7 @@ class MtProtoSender:
# region Send and receive
def send(self, *requests):
async def send(self, *requests):
"""Sends the specified MTProtoRequest, previously sending any message
which needed confirmation."""
@@ -80,13 +80,13 @@ class MtProtoSender:
else:
message = TLMessage(self.session, MessageContainer(messages))
self._send_message(message)
await self._send_message(message)
def _send_acknowledge(self, msg_id):
async def _send_acknowledge(self, msg_id):
"""Sends a message acknowledge for the given msg_id"""
self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
def receive(self, update_state):
async def receive(self, update_state):
"""Receives a single message from the connected endpoint.
This method returns nothing, and will only affect other parts
@@ -97,7 +97,7 @@ class MtProtoSender:
update_state.process(TLObject).
"""
try:
body = self.connection.recv()
body = await self.connection.recv()
except (BufferError, InvalidChecksumError):
# TODO BufferError, we should spot the cause...
# "No more bytes left"; something wrong happened, clear
@@ -111,13 +111,13 @@ class MtProtoSender:
message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader:
self._process_msg(remote_msg_id, remote_seq, reader, update_state)
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
# endregion
# region Low level processing
def _send_message(self, message):
async def _send_message(self, message):
"""Sends the given Message(TLObject) encrypted through the network"""
plain_text = \
@@ -130,7 +130,7 @@ class MtProtoSender:
cipher_text = AES.encrypt_ige(plain_text, key, iv)
result = key_id + msg_key + cipher_text
self.connection.send(result)
await self.connection.send(result)
def _decode_msg(self, body):
"""Decodes an received encrypted message body bytes"""
@@ -163,7 +163,7 @@ class MtProtoSender:
return message, remote_msg_id, remote_sequence
def _process_msg(self, msg_id, sequence, reader, state):
async def _process_msg(self, msg_id, sequence, reader, state):
"""Processes and handles a Telegram message.
Returns True if the message was handled correctly and doesn't
@@ -178,22 +178,22 @@ class MtProtoSender:
# The following codes are "parsed manually"
if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
return self._handle_rpc_result(msg_id, sequence, reader)
return await self._handle_rpc_result(msg_id, sequence, reader)
if code == 0x347773c5: # pong
return self._handle_pong(msg_id, sequence, reader)
return await self._handle_pong(msg_id, sequence, reader)
if code == 0x73f1f8dc: # msg_container
return self._handle_container(msg_id, sequence, reader, state)
return await self._handle_container(msg_id, sequence, reader, state)
if code == 0x3072cfa1: # gzip_packed
return self._handle_gzip_packed(msg_id, sequence, reader, state)
return await self._handle_gzip_packed(msg_id, sequence, reader, state)
if code == 0xedab447b: # bad_server_salt
return self._handle_bad_server_salt(msg_id, sequence, reader)
return await self._handle_bad_server_salt(msg_id, sequence, reader)
if code == 0xa7eff811: # bad_msg_notification
return self._handle_bad_msg_notification(msg_id, sequence, reader)
return await self._handle_bad_msg_notification(msg_id, sequence, reader)
# msgs_ack, it may handle the request we wanted
if code == 0x62d6b459:
@@ -247,7 +247,7 @@ class MtProtoSender:
r.confirm_received.set()
self._pending_receive.clear()
def _handle_pong(self, msg_id, sequence, reader):
async def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong')
reader.read_int(signed=False) # code
received_msg_id = reader.read_long()
@@ -259,7 +259,7 @@ class MtProtoSender:
return True
def _handle_container(self, msg_id, sequence, reader, state):
async def _handle_container(self, msg_id, sequence, reader, state):
self._logger.debug('Handling container')
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
begin_position = reader.tell_position()
@@ -267,7 +267,7 @@ class MtProtoSender:
# Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session)
try:
if not self._process_msg(inner_msg_id, sequence, reader, state):
if not await self._process_msg(inner_msg_id, sequence, reader, state):
reader.set_position(begin_position + inner_len)
except:
# If any error is raised, something went wrong; skip the packet
@@ -276,7 +276,7 @@ class MtProtoSender:
return True
def _handle_bad_server_salt(self, msg_id, sequence, reader):
async def _handle_bad_server_salt(self, msg_id, sequence, reader):
self._logger.debug('Handling bad server salt')
reader.read_int(signed=False) # code
bad_msg_id = reader.read_long()
@@ -287,11 +287,11 @@ class MtProtoSender:
request = self._pop_request(bad_msg_id)
if request:
self.send(request)
await self.send(request)
return True
def _handle_bad_msg_notification(self, msg_id, sequence, reader):
async def _handle_bad_msg_notification(self, msg_id, sequence, reader):
self._logger.debug('Handling bad message notification')
reader.read_int(signed=False) # code
reader.read_long() # request_id
@@ -318,7 +318,7 @@ class MtProtoSender:
else:
raise error
def _handle_rpc_result(self, msg_id, sequence, reader):
async def _handle_rpc_result(self, msg_id, sequence, reader):
self._logger.debug('Handling RPC result')
reader.read_int(signed=False) # code
request_id = reader.read_long()
@@ -338,7 +338,7 @@ class MtProtoSender:
)
# Acknowledge that we received the error
self._send_acknowledge(request_id)
await self._send_acknowledge(request_id)
if request:
request.rpc_error = error
@@ -366,9 +366,9 @@ class MtProtoSender:
self._logger.debug('Lost request will be skipped.')
return False
def _handle_gzip_packed(self, msg_id, sequence, reader, state):
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
self._logger.debug('Handling gzip packed data')
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
return self._process_msg(msg_id, sequence, compressed_reader, state)
return await self._process_msg(msg_id, sequence, compressed_reader, state)
# endregion