Stick to the offset and limit CdnFileHashes dictates (#222)

The old intersection method and allowing an arbitrary part size
wasn't working properly. Assuming that Telegram will send a sha
sum for the whole file, in the correct order, we can simply use
their offsets to download the file.
This commit is contained in:
Lonami Exo 2017-09-05 16:43:53 +02:00
parent 2924912931
commit 49555ad018
2 changed files with 31 additions and 46 deletions

View File

@ -2,7 +2,7 @@ from hashlib import sha256
from ..tl import Session from ..tl import Session
from ..tl.functions.upload import GetCdnFileRequest, ReuploadCdnFileRequest from ..tl.functions.upload import GetCdnFileRequest, ReuploadCdnFileRequest
from ..tl.types.upload import CdnFileReuploadNeeded from ..tl.types.upload import CdnFileReuploadNeeded, CdnFile
from ..crypto import AESModeCTR from ..crypto import AESModeCTR
from ..errors import CdnFileTamperedError from ..errors import CdnFileTamperedError
@ -20,7 +20,7 @@ class CdnDecrypter:
self.shaes = [sha256() for _ in range(len(cdn_file_hashes))] self.shaes = [sha256() for _ in range(len(cdn_file_hashes))]
@staticmethod @staticmethod
def prepare_decrypter(client, client_cls, cdn_redirect, offset, part_size): def prepare_decrypter(client, client_cls, cdn_redirect):
"""Prepares a CDN decrypter, returning (decrypter, file data). """Prepares a CDN decrypter, returning (decrypter, file data).
'client' should be the original TelegramBareClient that 'client' should be the original TelegramBareClient that
tried to download the file. tried to download the file.
@ -31,8 +31,8 @@ class CdnDecrypter:
# https://core.telegram.org/cdn # https://core.telegram.org/cdn
cdn_aes = AESModeCTR( cdn_aes = AESModeCTR(
key=cdn_redirect.encryption_key, key=cdn_redirect.encryption_key,
iv= # 12 first bytes of the IV..4 bytes of the offset (0, big endian)
cdn_redirect.encryption_iv[:12] + (offset >> 4).to_bytes(4, 'big') iv=cdn_redirect.encryption_iv[:12] + bytes(4)
) )
# Create a new client on said CDN # Create a new client on said CDN
@ -44,9 +44,14 @@ class CdnDecrypter:
session, client.api_id, client.api_hash, session, client.api_id, client.api_hash,
timeout=client._timeout timeout=client._timeout
) )
# This will make use of the new RSA keys for this specific CDN # This will make use of the new RSA keys for this specific CDN.
#
# We assume that cdn_redirect.cdn_file_hashes are ordered by offset,
# and that there will be enough of these to retrieve the whole file.
cdn_file = cdn_client.connect(initial_query=GetCdnFileRequest( cdn_file = cdn_client.connect(initial_query=GetCdnFileRequest(
cdn_redirect.file_token, offset, part_size file_token=cdn_redirect.file_token,
offset=cdn_redirect.cdn_file_hashes[0].offset,
limit=cdn_redirect.cdn_file_hashes[0].limit
)) ))
# CDN client is ready, create the resulting CdnDecrypter # CDN client is ready, create the resulting CdnDecrypter
@ -63,51 +68,32 @@ class CdnDecrypter:
)) ))
# We want to always return a valid upload.CdnFile # We want to always return a valid upload.CdnFile
cdn_file = decrypter.get_file(offset, part_size) cdn_file = decrypter.get_file()
else: else:
cdn_file.bytes = decrypter.cdn_aes.encrypt(cdn_file.bytes) cdn_file.bytes = decrypter.cdn_aes.encrypt(cdn_file.bytes)
decrypter.check(offset, cdn_file.bytes) cdn_hash = decrypter.cdn_file_hashes.pop(0)
decrypter.check(cdn_file.bytes, cdn_hash)
return decrypter, cdn_file return decrypter, cdn_file
def get_file(self, offset, limit): def get_file(self):
"""Calls GetCdnFileRequest and decrypts its bytes. """Calls GetCdnFileRequest and decrypts its bytes.
Also ensures that the file hasn't been tampered. Also ensures that the file hasn't been tampered.
""" """
result = self.client(GetCdnFileRequest(self.file_token, offset, limit)) if self.cdn_file_hashes:
result.bytes = self.cdn_aes.encrypt(result.bytes) cdn_hash = self.cdn_file_hashes.pop(0)
self.check(offset, result.bytes) cdn_file = self.client(GetCdnFileRequest(
return result self.file_token, cdn_hash.offset, cdn_hash.limit
))
def check(self, offset, data): cdn_file.bytes = self.cdn_aes.encrypt(cdn_file.bytes)
"""Checks the integrity of the given data""" self.check(cdn_file.bytes, cdn_hash)
for cdn_hash, sha in zip(self.cdn_file_hashes, self.shaes):
inter = self.intersect(
cdn_hash.offset, cdn_hash.offset + cdn_hash.limit,
offset, offset + len(data)
)
if inter:
x1, x2 = inter[0] - offset, inter[1] - offset
sha.update(data[x1:x2])
elif offset > cdn_hash.offset:
if cdn_hash.hash == sha.digest():
self.cdn_file_hashes.remove(cdn_hash)
self.shaes.remove(sha)
else: else:
raise CdnFileTamperedError() cdn_file = CdnFile(bytes(0))
def finish_check(self): return cdn_file
"""Similar to the check method, but for all unchecked hashes"""
for cdn_hash, sha in zip(self.cdn_file_hashes, self.shaes):
if cdn_hash.hash != sha.digest():
raise CdnFileTamperedError()
self.cdn_file_hashes.clear()
self.shaes.clear()
@staticmethod @staticmethod
def intersect(x1, x2, z1, z2): def check(data, cdn_hash):
if x1 > z1: """Checks the integrity of the given data"""
return None if x1 > z2 else (x1, min(x2, z2)) if sha256(data).digest() != cdn_hash.hash:
else: raise CdnFileTamperedError()
return (z1, min(x2, z2)) if x2 > z1 else None

View File

@ -475,7 +475,7 @@ class TelegramBareClient:
try: try:
if cdn_decrypter: if cdn_decrypter:
result = cdn_decrypter.get_file(offset, part_size) result = cdn_decrypter.get_file()
else: else:
result = client(GetFileRequest( result = client(GetFileRequest(
input_location, offset, part_size input_location, offset, part_size
@ -484,8 +484,7 @@ class TelegramBareClient:
if isinstance(result, FileCdnRedirect): if isinstance(result, FileCdnRedirect):
cdn_decrypter, result = \ cdn_decrypter, result = \
CdnDecrypter.prepare_decrypter( CdnDecrypter.prepare_decrypter(
client, TelegramBareClient, result, client, TelegramBareClient, result
offset, part_size
) )
except FileMigrateError as e: except FileMigrateError as e: