re-work multi --styles-file

--styles-file change to append str
--styles-file is [] then defaults to [styles.csv]

--styles-file accepts paths or paths with wildcard "*"

the first `--styles-file` entry is use as the default styles file path
if filename a wildcard then the first matching file is used
if no match is found, create a new "styles.csv" in the same dir as the first path

when saving a new style it will be save in the default styles file
when saving a existing style, it will be saved to file it belongs to

order of the styles files in the styles dropdown can be controlled to a certain degree by the order of --styles-file
This commit is contained in:
w-e-w
2024-01-21 02:21:36 +09:00
parent f939bce845
commit 25e8273d2f
4 changed files with 49 additions and 44 deletions

View File

@@ -1,16 +1,15 @@
from pathlib import Path
import csv
import fnmatch
import os
import os.path
import typing
import shutil
class PromptStyle(typing.NamedTuple):
name: str
prompt: str
negative_prompt: str
path: str = None
prompt: str | None
negative_prompt: str | None
path: str | None = None
def merge_prompts(style_prompt: str, prompt: str) -> str:
@@ -79,14 +78,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
class StyleDatabase:
def __init__(self, path: str):
def __init__(self, paths: list[str | Path]):
self.no_style = PromptStyle("None", "", "", None)
self.styles = {}
self.path = path
self.paths = paths
self.all_styles_files: list[Path] = []
folder, file = os.path.split(self.path)
filename, _, ext = file.partition('*')
self.default_path = os.path.join(folder, filename + ext)
folder, file = os.path.split(self.paths[0])
if '*' in file or '?' in file:
# if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
self.paths.insert(0, self.default_path)
else:
self.default_path = Path(self.paths[0])
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
@@ -99,33 +103,31 @@ class StyleDatabase:
"""
self.styles.clear()
path, filename = os.path.split(self.path)
# scans for all styles files
all_styles_files = []
for pattern in self.paths:
folder, file = os.path.split(pattern)
if '*' in file or '?' in file:
found_files = Path(folder).glob(file)
[all_styles_files.append(file) for file in found_files]
else:
# if os.path.exists(pattern):
all_styles_files.append(Path(pattern))
if "*" in filename:
fileglob = filename.split("*")[0] + "*.csv"
filelist = []
for file in os.listdir(path):
if fnmatch.fnmatch(file, fileglob):
filelist.append(file)
# Add a visible divider to the style list
half_len = round(len(file) / 2)
divider = f"{'-' * (20 - half_len)} {file.upper()}"
divider = f"{divider} {'-' * (40 - len(divider))}"
self.styles[divider] = PromptStyle(
f"{divider}", None, None, "do_not_save"
)
# Add styles from this CSV file
self.load_from_csv(os.path.join(path, file))
if len(filelist) == 0:
print(f"No styles found in {path} matching {fileglob}")
return
elif not os.path.exists(self.path):
print(f"Style database not found: {self.path}")
return
else:
self.load_from_csv(self.path)
# Remove any duplicate entries
seen = set()
self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
def load_from_csv(self, path: str):
for styles_file in self.all_styles_files:
if len(all_styles_files) > 1:
# add divider when more than styles file
# '---------------- STYLES ----------------'
divider = f' {styles_file.stem.upper()} '.center(40, '-')
self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
if styles_file.is_file():
self.load_from_csv(styles_file)
def load_from_csv(self, path: str | Path):
with open(path, "r", encoding="utf-8-sig", newline="") as file:
reader = csv.DictReader(file, skipinitialspace=True)
for row in reader:
@@ -137,7 +139,7 @@ class StyleDatabase:
negative_prompt = row.get("negative_prompt", "")
# Add style to database
self.styles[row["name"]] = PromptStyle(
row["name"], prompt, negative_prompt, path
row["name"], prompt, negative_prompt, str(path)
)
def get_style_paths(self) -> set:
@@ -145,11 +147,11 @@ class StyleDatabase:
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
self.styles[style.name] = style._replace(path=self.default_path)
self.styles[style.name] = style._replace(path=str(self.default_path))
# Create a list of all distinct paths, including the default path
style_paths = set()
style_paths.add(self.default_path)
style_paths.add(str(self.default_path))
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)
@@ -177,7 +179,6 @@ class StyleDatabase:
def save_styles(self, path: str = None) -> None:
# The path argument is deprecated, but kept for backwards compatibility
_ = path
style_paths = self.get_style_paths()