mirror of
https://github.com/meeb/tubesync.git
synced 2025-06-22 21:16:38 +00:00
204 lines
6.7 KiB
Python
204 lines
6.7 KiB
Python
import cProfile
|
|
import emoji
|
|
import io
|
|
import os
|
|
import pstats
|
|
import string
|
|
import time
|
|
from datetime import datetime
|
|
from urllib.parse import urlunsplit, urlencode, urlparse
|
|
from yt_dlp.utils import LazyList
|
|
from .errors import DatabaseConnectionError
|
|
|
|
|
|
def getenv(key, default=None, /, *, integer=False, string=True):
    '''
        Guarantees a returned type from calling `os.getenv`
        The caller can request the integer type,
          or use the default string type.

        `key` must be a `str`; `default` may be a `bool`, `float`, `int`,
        `str` or `None`; `integer` and `string` must be `bool`.
        An `AssertionError` is raised for any other argument type.

        When the variable is unset and no default was given, the result is
        `0` if `integer` is true, otherwise `''` if `string` is true,
        otherwise `None`.
        When a value is available and `integer` is true it is converted to
        an `int` (values such as `'12.7'` are truncated to `12`); a value
        that is not numeric raises `ValueError`.
    '''

    args = dict(key=key, default=default, integer=integer, string=string)
    # One tuple of accepted types per argument, in the same order as `args`,
    # so that every argument (including `string`) is actually checked.
    supported_types = dict(zip(args.keys(), (
        (str,), # key
        (
            bool,
            float,
            int,
            str,
            None.__class__,
        ), # default
        (bool,), # integer
        (bool,), # string
    )))
    unsupported_type_msg = 'Unsupported type for positional argument, "{}": {}'
    for k, t in supported_types.items():
        v = args[k]
        assert isinstance(v, t), unsupported_type_msg.format(k, type(v))

    # os.getenv only deals in strings, so the default is stringified first
    d = str(default) if default is not None else None

    r = os.getenv(key, d)
    if r is None:
        if string: r = str()
        if integer: r = int()
    elif integer:
        try:
            # Exact conversion; avoids rounding large values through a float
            r = int(r)
        except ValueError:
            # Accept float-style text such as '12.7' or '1e3', truncating
            r = int(float(r))
    return r
|
|
|
|
|
|
def parse_database_connection_string(database_connection_string):
    '''
        Parses a connection string in a URL style format, such as:
            postgresql://tubesync:password@localhost:5432/tubesync
            mysql://someuser:somepassword@localhost:3306/tubesync
        into a Django-compatible settings.DATABASES dict format.

        The returned dict has the keys DRIVER, ENGINE, NAME, USER, PASSWORD,
        HOST, PORT, CONN_MAX_AGE and OPTIONS. When the connection string has
        no port, the default port for the driver is used.

        Raises DatabaseConnectionError when the string cannot be parsed, the
        driver is not supported, the netloc is not `user:pass@host[:port]`,
        the port is not an integer between 1 and 65535, or the path is not a
        single database name.
    '''
    valid_drivers = ('postgresql', 'mysql')
    default_ports = {
        'postgresql': 5432,
        'mysql': 3306,
    }
    django_backends = {
        'postgresql': 'django.db.backends.postgresql',
        'mysql': 'django.db.backends.mysql',
    }
    backend_options = {
        'postgresql': {},
        'mysql': {
            'charset': 'utf8mb4',
        }
    }
    try:
        parts = urlparse(str(database_connection_string))
    except Exception as e:
        raise DatabaseConnectionError(f'Failed to parse "{database_connection_string}" '
                                      f'as a database connection string: {e}') from e
    driver = parts.scheme
    user_pass_host_port = parts.netloc
    database = parts.path
    if driver not in valid_drivers:
        raise DatabaseConnectionError(f'Database connection string '
                                      f'"{database_connection_string}" specified an '
                                      f'invalid driver, must be one of {valid_drivers}')
    django_driver = django_backends.get(driver)
    # The netloc must contain exactly one "@" separating credentials and host
    host_parts = user_pass_host_port.split('@')
    if len(host_parts) != 2:
        raise DatabaseConnectionError('Database connection string netloc must be in '
                                      'the format of user:pass@host')
    user_pass, host_port = host_parts
    # The credentials must contain exactly one ":" separating user and password
    user_pass_parts = user_pass.split(':')
    if len(user_pass_parts) != 2:
        raise DatabaseConnectionError('Database connection string netloc must be in '
                                      'the format of user:pass@host')
    username, password = user_pass_parts
    host_port_parts = host_port.split(':')
    if len(host_port_parts) == 1:
        # No port number, assign a default port
        hostname = host_port_parts[0]
        port = default_ports.get(driver)
    elif len(host_port_parts) == 2:
        # Host name and port number
        hostname, port = host_port_parts
        try:
            port = int(port)
        except (ValueError, TypeError) as e:
            raise DatabaseConnectionError(f'Database connection string contained an '
                                          f'invalid port, ports must be integers: '
                                          f'{e}') from e
        # Valid port numbers are 1 to 65535 inclusive
        if not 0 < port <= 65535:
            raise DatabaseConnectionError(f'Database connection string contained an '
                                          f'invalid port, ports must be between 1 and '
                                          f'65535, got {port}')
    else:
        # Malformed
        raise DatabaseConnectionError('Database connection host must be a hostname or '
                                      'a hostname:port combination')
    # Drop the single leading "/" of the URL path to get the database name
    if database.startswith('/'):
        database = database[1:]
    if not database:
        raise DatabaseConnectionError('Database connection string path must be a '
                                      'string in the format of /databasename')
    if '/' in database:
        raise DatabaseConnectionError(f'Database connection string path can only '
                                      f'contain a single string name, got: {database}')
    return {
        'DRIVER': driver,
        'ENGINE': django_driver,
        'NAME': database,
        'USER': username,
        'PASSWORD': password,
        'HOST': hostname,
        'PORT': port,
        'CONN_MAX_AGE': 300,
        'OPTIONS': backend_options.get(driver),
    }
|
|
|
|
|
|
def get_client_ip(request):
    '''
        Returns the client address recorded in `request.META`.

        The first comma-separated entry of `HTTP_X_FORWARDED_FOR` is used
        when that header is present and non-empty, otherwise the value of
        `REMOTE_ADDR` (which may be `None` when it is missing).
    '''
    meta = request.META
    forwarded = meta.get('HTTP_X_FORWARDED_FOR')
    if not forwarded:
        return meta.get('REMOTE_ADDR')
    # Only the text before the first comma is of interest
    first_entry, _, _ = forwarded.partition(',')
    return first_entry
|
|
|
|
|
|
def append_uri_params(uri, params):
    '''
        Builds a URI string from `uri` and the URL-encoded `params`.

        `uri` is converted to a string and passed as the path component,
        `params` is encoded with `urlencode` and passed as the query
        component; scheme, netloc and fragment are left empty.
    '''
    path = str(uri)
    query = urlencode(params)
    # (scheme, netloc, path, query, fragment)
    components = ('', '', path, query, '')
    return urlunsplit(components)
|
|
|
|
|
|
def clean_filename(filename):
    '''
        Returns `filename` with unsafe characters removed.

        The characters `<>\\/:*?"|%` and the ASCII control characters
        (code points 0 to 31) are dropped, except that the ASCII whitespace
        characters in `string.whitespace` are each replaced by a single
        space. Leading and trailing whitespace is stripped from the result.

        Raises ValueError when `filename` is not a `str`.
    '''
    if not isinstance(filename, str):
        raise ValueError(f'filename must be a str, got {type(filename)}')
    to_scrub = r'<>\/:*?"|%'
    # Build one translation table and apply it in a single pass.
    # Control characters are removed first ...
    table = {code_point: None for code_point in range(32)}
    table.update({ord(char): None for char in to_scrub})
    # ... then whitespace overrides that, so tabs and newlines become spaces
    table.update({ord(char): ' ' for char in string.whitespace})
    return filename.translate(table).strip()
|
|
|
|
|
|
def clean_emoji(s):
    '''
        Returns the result of `emoji.replace_emoji` called on `s` with its
        default replacement.

        Raises ValueError when `s` is not a `str`.
    '''
    if isinstance(s, str):
        return emoji.replace_emoji(s)
    raise ValueError(f'parameter must be a str, got {type(s)}')
|
|
|
|
|
|
def json_serial(obj):
    '''
        Converts objects that `json` cannot serialise by itself.

        `datetime` instances become their ISO 8601 string and `LazyList`
        instances become a plain `list`. Any other type raises TypeError.
    '''
    if isinstance(obj, datetime):
        serialised = obj.isoformat()
    elif isinstance(obj, LazyList):
        serialised = list(obj)
    else:
        raise TypeError(f'Type {type(obj)} is not json_serial()-able')
    return serialised
|
|
|
|
|
|
def time_func(func):
    '''
        Decorator that measures how long a call to `func` takes.

        The decorated callable returns `(result, (elapsed, start, end))`
        where `result` is what `func` returned and `start` / `end` are
        `time.perf_counter()` readings taken immediately before and after
        the call, with `elapsed` being `end - start`.
    '''
    from functools import wraps

    # Keep the name, docstring and other metadata of the wrapped function
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        return (result, (end - start, start, end,),)
    return wrapper
|
|
|
|
|
|
def profile_func(func):
    '''
        Decorator that runs `func` under `cProfile`.

        The decorated callable returns `(result, (report, stats, stream))`
        where `result` is what `func` returned, `stats` is the
        `pstats.Stats` object sorted by cumulative time, `stream` is the
        `io.StringIO` the statistics were printed to and `report` is the
        text held by that stream.
    '''
    from functools import wraps

    # Keep the name, docstring and other metadata of the wrapped function
    @wraps(func)
    def wrapper(*args, **kwargs):
        s = io.StringIO()
        # The context manager enables the profiler on entry and disables it
        # on exit, so no explicit enable() / disable() calls are made here;
        # enabling an already enabled profiler is at best redundant.
        with cProfile.Profile() as pr:
            result = func(*args, **kwargs)
        # Collect the statistics only once profiling has stopped
        ps = pstats.Stats(pr, stream=s)
        ps.sort_stats(
            pstats.SortKey.CUMULATIVE
        ).print_stats()
        return (result, (s.getvalue(), ps, s),)
    return wrapper
|
|
|