| |
| from __future__ import annotations |
|
|
| import csv |
| import os |
| import os.path |
| import typing |
| import collections.abc as abc |
| import tempfile |
| import shutil |
|
|
| if typing.TYPE_CHECKING: |
| |
| from .processing import StableDiffusionProcessing |
|
|
|
|
class PromptStyle(typing.NamedTuple):
    """One named style: a positive prompt text and a negative prompt text.

    Being a ``NamedTuple``, instances are immutable and expose ``_fields``
    and ``_asdict()``, in the declaration order below.
    """

    # Identifier of the style.
    name: str
    # Text merged into the positive prompt.
    prompt: str
    # Text merged into the negative prompt.
    negative_prompt: str
|
|
|
|
def merge_prompts(style_prompt: str, prompt: str) -> str:
    """Combine a style's text with a prompt.

    If ``style_prompt`` contains the literal ``{prompt}``, every occurrence
    is replaced by ``prompt`` (used as given, without stripping). Otherwise
    the stripped prompt and the stripped style text are joined with ``", "``,
    prompt first, leaving out whichever of the two is empty.
    """
    placeholder = "{prompt}"
    if placeholder in style_prompt:
        return style_prompt.replace(placeholder, prompt)

    pieces = [piece for piece in (prompt.strip(), style_prompt.strip()) if piece]
    return ", ".join(pieces)
|
|
|
|
def apply_styles_to_prompt(prompt, styles):
    """Merge each style text in ``styles`` into ``prompt``, in order.

    Every style text is combined with the result accumulated so far through
    ``merge_prompts``; with no styles the prompt is returned unchanged.
    """
    merged = prompt
    for style_text in styles:
        merged = merge_prompts(style_text, merged)
    return merged
|
|
|
|
class StyleDatabase:
    """Named prompt styles loaded from, and saved to, a CSV file.

    The CSV is expected to have a ``name`` column, a ``prompt`` column (or
    ``text`` when there is no ``prompt`` column) and optionally a
    ``negative_prompt`` column. A built-in style called ``"None"`` with empty
    texts is always present, and lookups of unknown names fall back to it.
    """

    def __init__(self, path: str):
        """Load styles from ``path``; a missing file leaves only ``"None"``.

        Raises:
            KeyError: if a row has no ``name`` column, or has neither a
                ``prompt`` nor a ``text`` column.
        """
        self.no_style = PromptStyle("None", "", "")
        self.styles = {"None": self.no_style}

        if not os.path.exists(path):
            return

        with open(path, "r", encoding="utf-8-sig", newline='') as file:
            reader = csv.DictReader(file)
            for row in reader:
                # Fall back to the "text" column when there is no "prompt" column.
                prompt = row["prompt"] if "prompt" in row else row["text"]
                # csv.DictReader fills cells missing from a short row with
                # None; store "" instead so later string merging cannot fail.
                negative_prompt = row.get("negative_prompt") or ""
                self.styles[row["name"]] = PromptStyle(row["name"], prompt or "", negative_prompt)

    def get_style_prompts(self, styles: abc.Iterable[str]) -> list[str]:
        """Return the prompt text of each named style ("" for unknown names)."""
        return [self.styles.get(x, self.no_style).prompt for x in styles]

    def get_negative_style_prompts(self, styles: abc.Iterable[str]) -> list[str]:
        """Return the negative prompt text of each named style ("" for unknown names)."""
        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

    def apply_styles_to_prompt(self, prompt, styles):
        """Merge the prompt texts of the named styles into ``prompt``."""
        # Inside a method body this name resolves to the module-level
        # function, not to this method.
        return apply_styles_to_prompt(prompt, self.get_style_prompts(styles))

    def apply_negative_styles_to_prompt(self, prompt, styles):
        """Merge the negative prompt texts of the named styles into ``prompt``."""
        return apply_styles_to_prompt(prompt, self.get_negative_style_prompts(styles))

    def save_styles(self, path: str) -> None:
        """Write all styles (including ``"None"``) to ``path`` as CSV.

        The data is first written to a temporary file. An existing file at
        ``path`` is then moved to ``path + ".bak"`` before the temporary file
        is moved into place.
        """
        # Write to a temporary file first so a failed write cannot damage
        # the existing styles file.
        fd, temp_path = tempfile.mkstemp(".csv")
        try:
            with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
                writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
                writer.writeheader()
                writer.writerows(style._asdict() for style in self.styles.values())
        except BaseException:
            # Do not leave an orphaned temporary file behind when writing fails.
            try:
                os.remove(temp_path)
            except OSError:
                pass
            raise

        # Keep the previous version as a backup before replacing it.
        if os.path.exists(path):
            shutil.move(path, path + ".bak")
        shutil.move(temp_path, path)
|
|