Use async def everywhere

This commit is contained in:
Lonami Exo
2017-10-06 21:02:41 +02:00
parent 9716d1d543
commit 77c99db066
7 changed files with 206 additions and 208 deletions

View File

@@ -1,14 +1,13 @@
import errno
import os
import struct
from datetime import timedelta
from zlib import crc32
from enum import Enum
import errno
from zlib import crc32
from ..crypto import AESModeCTR
from ..extensions import TcpClient
from ..errors import InvalidChecksumError
from ..extensions import TcpClient
class ConnectionMode(Enum):
@@ -74,9 +73,9 @@ class Connection:
setattr(self, 'write', self._write_plain)
setattr(self, 'read', self._read_plain)
def connect(self, ip, port):
async def connect(self, ip, port):
try:
self.conn.connect(ip, port)
await self.conn.connect(ip, port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
@@ -85,16 +84,16 @@ class Connection:
self._send_counter = 0
if self._mode == ConnectionMode.TCP_ABRIDGED:
self.conn.write(b'\xef')
await self.conn.write(b'\xef')
elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
self.conn.write(b'\xee\xee\xee\xee')
await self.conn.write(b'\xee\xee\xee\xee')
elif self._mode == ConnectionMode.TCP_OBFUSCATED:
self._setup_obfuscation()
await self._setup_obfuscation()
def get_timeout(self):
return self.conn.timeout
def _setup_obfuscation(self):
async def _setup_obfuscation(self):
# Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
while True:
@@ -119,7 +118,7 @@ class Connection:
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
self.conn.write(bytes(random))
await self.conn.write(bytes(random))
def is_connected(self):
return self.conn.connected
@@ -135,20 +134,23 @@ class Connection:
# region Receive message implementations
def recv(self):
async def recv(self):
"""Receives and unpacks a message"""
# Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _recv_tcp_full(self):
packet_length_bytes = self.read(4)
async def _recv_tcp_full(self):
# TODO We don't want another call to this method that could
# potentially await on another self.read(n). Is this guaranteed
# by asyncio?
packet_length_bytes = await self.read(4)
packet_length = int.from_bytes(packet_length_bytes, 'little')
seq_bytes = self.read(4)
seq_bytes = await self.read(4)
seq = int.from_bytes(seq_bytes, 'little')
body = self.read(packet_length - 12)
checksum = int.from_bytes(self.read(4), 'little')
body = await self.read(packet_length - 12)
checksum = int.from_bytes(await self.read(4), 'little')
valid_checksum = crc32(packet_length_bytes + seq_bytes + body)
if checksum != valid_checksum:
@@ -156,72 +158,70 @@ class Connection:
return body
def _recv_intermediate(self):
return self.read(int.from_bytes(self.read(4), 'little'))
async def _recv_intermediate(self):
return await self.read(int.from_bytes(self.read(4), 'little'))
def _recv_abridged(self):
async def _recv_abridged(self):
length = int.from_bytes(self.read(1), 'little')
if length >= 127:
length = int.from_bytes(self.read(3) + b'\0', 'little')
return self.read(length << 2)
return await self.read(length << 2)
# endregion
# region Send message implementations
def send(self, message):
async def send(self, message):
"""Encapsulates and sends the given message"""
# Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _send_tcp_full(self, message):
async def _send_tcp_full(self, message):
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(message) + 12
data = struct.pack('<ii', length, self._send_counter) + message
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
self.write(data + crc)
await self.write(data + crc)
def _send_intermediate(self, message):
self.write(struct.pack('<i', len(message)) + message)
async def _send_intermediate(self, message):
await self.write(struct.pack('<i', len(message)) + message)
def _send_abridged(self, message):
async def _send_abridged(self, message):
length = len(message) >> 2
if length < 127:
length = struct.pack('B', length)
else:
length = b'\x7f' + int.to_bytes(length, 3, 'little')
self.write(length + message)
await self.write(length + message)
# endregion
# region Read implementations
def read(self, length):
async def read(self, length):
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _read_plain(self, length):
return self.conn.read(length)
async def _read_plain(self, length):
return await self.conn.read(length)
def _read_obfuscated(self, length):
return self._aes_decrypt.encrypt(
self.conn.read(length)
)
async def _read_obfuscated(self, length):
return await self._aes_decrypt.encrypt(self.conn.read(length))
# endregion
# region Write implementations
def write(self, data):
async def write(self, data):
raise ValueError('Invalid connection mode specified: ' + str(self._mode))
def _write_plain(self, data):
self.conn.write(data)
async def _write_plain(self, data):
await self.conn.write(data)
def _write_obfuscated(self, data):
self.conn.write(self._aes_encrypt.encrypt(data))
async def _write_obfuscated(self, data):
await self.conn.write(self._aes_encrypt.encrypt(data))
# endregion