Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-08-04 11:12:35 +00:00)
Fix DAT models download (#16302)
@@ -211,3 +211,80 @@ Requested path was: {path}
        subprocess.Popen(["explorer.exe", subprocess.check_output(["wslpath", "-w", path])])
    else:
        subprocess.Popen(["xdg-open", path])

def load_file_from_url(
    url: str,
    *,
    model_dir: str,
    progress: bool = True,
    file_name: str | None = None,
    hash_prefix: str | None = None,
    re_download: bool = False,
) -> str:
    """Download a file from `url` into `model_dir`, using the file present if possible.
    Returns the path to the downloaded file.

    file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url.
        file is downloaded to {file_name}.tmp then moved to the final location after download is complete.
    hash_prefix: sha256 hex string, if provided, the hash of the downloaded file will be checked against this prefix.
        if the hash does not match, the temporary file is deleted and a ValueError is raised.
    re_download: forcibly re-download the file even if it already exists.
    """
    from urllib.parse import urlparse
    import requests
    try:
        from tqdm import tqdm
    except ImportError:
        class tqdm:
            def __init__(self, *args, **kwargs):
                pass

            def update(self, n=1, *args, **kwargs):
                pass

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                pass

    if not file_name:
        parts = urlparse(url)
        file_name = os.path.basename(parts.path)

    cached_file = os.path.abspath(os.path.join(model_dir, file_name))

    if re_download or not os.path.exists(cached_file):
        os.makedirs(model_dir, exist_ok=True)
        temp_file = os.path.join(model_dir, f"{file_name}.tmp")
        print(f'\nDownloading: "{url}" to {cached_file}')
        response = requests.get(url, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, disable=not progress) as progress_bar:
            with open(temp_file, 'wb') as file:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        file.write(chunk)
                        progress_bar.update(len(chunk))

        if hash_prefix and not compare_sha256(temp_file, hash_prefix):
            print(f"Hash mismatch for {temp_file}. Deleting the temporary file.")
            os.remove(temp_file)
            raise ValueError(f"File hash does not match the expected hash prefix {hash_prefix}!")

        os.rename(temp_file, cached_file)
    return cached_file
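
For reference, a minimal usage sketch of the new helper. The URL, directory, file name, and hash prefix below are hypothetical placeholders for illustration, not values taken from this commit:

    # Hypothetical call: fetch a DAT upscaler weight file and verify its hash.
    weights_path = load_file_from_url(
        "https://example.com/models/DAT/DAT_x4.pth",  # placeholder URL
        model_dir="models/DAT",                       # placeholder directory
        file_name="DAT_x4.pth",                       # placeholder file name
        hash_prefix="391a6ce3",                       # placeholder sha256 prefix
    )
    print(weights_path)  # absolute path to the cached file

If the file already exists at the target path, it is returned as-is unless re_download=True is passed.
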
def compare_sha256(file_path: str, hash_prefix: str) -> bool:
    """Check if the SHA256 hash of the file matches the given prefix."""
    import hashlib
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(blksize), b""):
            hash_sha256.update(chunk)
    return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
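
And a sketch of how a caller might obtain a value to pass as hash_prefix, by hashing a known-good copy of the file. The path is a placeholder; any prefix of the printed digest works, since compare_sha256 matches with startswith:

    import hashlib

    def sha256_hexdigest(file_path: str) -> str:
        # Hash the file in 1 MiB chunks, mirroring compare_sha256 above.
        digest = hashlib.sha256()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
                digest.update(chunk)
        return digest.hexdigest()

    print(sha256_hexdigest("models/DAT/DAT_x4.pth"))  # placeholder path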