| import gc |
| import os |
| import random |
| import numpy as np |
| import json |
| import torch |
| import uuid |
| from PIL import Image, PngImagePlugin |
| from datetime import datetime |
| from dataclasses import dataclass |
| from typing import Callable, Dict, Optional, Tuple |
| from diffusers import ( |
| DDIMScheduler, |
| DPMSolverMultistepScheduler, |
| DPMSolverSinglestepScheduler, |
| EulerAncestralDiscreteScheduler, |
| EulerDiscreteScheduler, |
| ) |
|
|
# Largest signed 32-bit integer (2**31 - 1); used as the inclusive upper
# bound when drawing random seeds.
MAX_SEED: int = int(np.iinfo(np.int32).max)
|
|
|
|
@dataclass
class StyleConfig:
    """Positive/negative prompt text that together describe one style.

    NOTE(review): not referenced in this part of the file; presumably
    ``prompt`` is a template like the style entries consumed elsewhere
    (formatted with ``prompt=``) — confirm against callers.
    """

    # Positive prompt text (or template) for the style.
    prompt: str
    # Negative prompt text associated with the style.
    negative_prompt: str
|
|
|
|
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed to use for a generation.

    Args:
        seed: Seed to keep when randomization is off.
        randomize_seed: When truthy, ignore ``seed`` and draw a fresh value.

    Returns:
        A uniformly random integer in ``[0, MAX_SEED]`` (inclusive) when
        ``randomize_seed`` is truthy, otherwise ``seed`` unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
|
|
|
def seed_everything(seed: int) -> torch.Generator:
    """Seed torch (CPU and all CUDA devices) and NumPy's global RNG.

    Python's ``random`` module is not seeded here.

    Args:
        seed: Seed value applied to every RNG.

    Returns:
        A new CPU ``torch.Generator`` seeded with ``seed``.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(seed)
    # Generator.manual_seed returns the generator itself.
    return torch.Generator().manual_seed(seed)
|
|
|
|
def parse_aspect_ratio(aspect_ratio: str) -> Optional[Tuple[int, int]]:
    """Parse a ``"<width> x <height>"`` label into integers.

    Args:
        aspect_ratio: Either ``"Custom"`` or a string such as ``"1024 x 768"``
            (the separator must be exactly ``" x "``).

    Returns:
        ``(width, height)``, or ``None`` for ``"Custom"``.

    Raises:
        ValueError: if the string does not split into exactly two integer parts.
    """
    if aspect_ratio != "Custom":
        width_text, height_text = aspect_ratio.split(" x ")
        return int(width_text), int(height_text)
    return None
|
|
|
|
def aspect_ratio_handler(
    aspect_ratio: str, custom_width: int, custom_height: int
) -> Tuple[int, int]:
    """Resolve the output size for a selected aspect-ratio label.

    Args:
        aspect_ratio: ``"Custom"`` or a ``"<width> x <height>"`` label.
        custom_width: Width used when ``aspect_ratio`` is ``"Custom"``.
        custom_height: Height used when ``aspect_ratio`` is ``"Custom"``.

    Returns:
        ``(width, height)``.
    """
    if aspect_ratio == "Custom":
        return custom_width, custom_height
    parsed_width, parsed_height = parse_aspect_ratio(aspect_ratio)
    return parsed_width, parsed_height
|
|
|
|
def get_scheduler(scheduler_config: Dict, name: str) -> Optional[object]:
    """Instantiate the diffusers scheduler matching a sampler display name.

    Args:
        scheduler_config: Configuration passed to each scheduler class's
            ``from_config`` (presumably an existing pipeline's
            ``scheduler.config`` — confirm against callers).
        name: Sampler label: ``"DPM++ 2M Karras"``, ``"DPM++ SDE Karras"``,
            ``"DPM++ 2M SDE Karras"``, ``"Euler"``, ``"Euler a"`` or ``"DDIM"``.

    Returns:
        A newly constructed scheduler instance, or ``None`` when ``name`` is
        not one of the labels above.
    """
    # Factories are lazy so only the requested scheduler is constructed.
    scheduler_factory_map = {
        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        ),
        "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
            scheduler_config
        ),
        "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
    }
    factory = scheduler_factory_map.get(name)
    if factory is None:
        return None
    return factory()
|
|
|
|
def free_memory() -> None:
    """Collect unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` must run first: tensors kept alive only by reference
    cycles are returned to PyTorch's caching allocator by the collector, and
    only afterwards can ``torch.cuda.empty_cache()`` hand those blocks back to
    the driver. The reverse order leaves the newly freed blocks cached.
    """
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
def preprocess_prompt(
    style_dict,
    style_name: str,
    positive: str,
    negative: str = "",
    add_style: bool = True,
) -> Tuple[str, str]:
    """Apply a named style to a positive/negative prompt pair.

    Args:
        style_dict: Mapping from style name to a ``(positive_template,
            negative_prompt)`` pair. ``positive_template`` is formatted with
            ``prompt=<positive>``. The ``"(None)"`` entry is the fallback
            used when ``style_name`` is not a key.
        style_name: Name of the style to apply.
        positive: The user's positive prompt.
        negative: The user's negative prompt, appended after the style's
            negative prompt.
        add_style: When false, ``positive`` is returned unformatted. The
            style's negative prompt is still included.

    Returns:
        ``(formatted_positive, combined_negative)``.

    Raises:
        KeyError: if ``style_name`` is not in ``style_dict`` and no
            ``"(None)"`` fallback exists.
    """
    # Resolve the fallback lazily: ``dict.get(key, default)`` evaluates
    # ``default`` eagerly, which raised KeyError for a perfectly valid
    # ``style_name`` whenever ``"(None)"`` was absent.
    if style_name in style_dict:
        p, n = style_dict[style_name]
    else:
        p, n = style_dict["(None)"]

    # Blank prompts are passed through unformatted.
    if add_style and positive.strip():
        formatted_positive = p.format(prompt=positive)
    else:
        formatted_positive = positive

    # Join the style negative and the user negative with ", ", skipping
    # whichever side is empty.
    combined_negative = n
    if negative.strip():
        combined_negative = n + ", " + negative if n else negative

    return formatted_positive, combined_negative
|
|
|
|
def common_upscale(
    samples: torch.Tensor,
    width: int,
    height: int,
    upscale_method: str,
) -> torch.Tensor:
    """Resize the spatial dimensions of ``samples`` to ``width`` x ``height``.

    Args:
        samples: Tensor accepted by ``torch.nn.functional.interpolate``
            (e.g. ``(N, C, H, W)`` for 2-D resizing).
        width: Target width.
        height: Target height.
        upscale_method: Interpolation ``mode`` (e.g. ``"nearest"``,
            ``"bilinear"``).

    Returns:
        The resized tensor.
    """
    # interpolate takes the target size in (height, width) order.
    target_size = (height, width)
    resized = torch.nn.functional.interpolate(
        samples, size=target_size, mode=upscale_method
    )
    return resized
|
|
|
|
def upscale(
    samples: torch.Tensor, upscale_method: str, scale_by: float
) -> torch.Tensor:
    """Scale the last two (height/width) axes of a 4-D tensor by ``scale_by``.

    The target size is rounded with Python's ``round`` and the resize is
    delegated to ``common_upscale``.

    Args:
        samples: Tensor whose dims 2 and 3 are height and width
            (presumably ``(N, C, H, W)`` — confirm against callers).
        upscale_method: Interpolation mode passed through to ``interpolate``.
        scale_by: Multiplicative scale factor for both axes.

    Returns:
        The resized tensor.
    """
    shape = samples.shape
    target_width = round(shape[3] * scale_by)
    target_height = round(shape[2] * scale_by)
    return common_upscale(samples, target_width, target_height, upscale_method)
|
|
|
|
def load_wildcard_files(wildcard_dir: str) -> Dict[str, str]:
    """Map ``__<name>__`` wildcard placeholders to ``.txt`` files in a directory.

    Each regular file ``<name>.txt`` directly inside ``wildcard_dir`` yields
    the key ``__<name>__`` mapped to its path. Only the final extension is
    stripped, so ``hair.color.txt`` becomes ``__hair.color__`` rather than
    being truncated at the first dot (which could also collide with
    ``hair.txt``). Entries are visited in sorted order so the mapping's
    iteration order is deterministic across platforms.

    Args:
        wildcard_dir: Directory to scan (not recursive).

    Returns:
        Dict from placeholder key to file path.

    Raises:
        OSError: if ``wildcard_dir`` cannot be listed.
    """
    wildcard_files = {}
    for file_name in sorted(os.listdir(wildcard_dir)):
        if not file_name.endswith(".txt"):
            continue
        file_path = os.path.join(wildcard_dir, file_name)
        # Skip directories (or other non-files) whose names end in ".txt";
        # they cannot be read as wildcard lists.
        if not os.path.isfile(file_path):
            continue
        stem = os.path.splitext(file_name)[0]
        wildcard_files[f"__{stem}__"] = file_path
    return wildcard_files
|
|
|
|
def get_random_line_from_file(file_path: str) -> str:
    """Return one randomly chosen non-blank line from a UTF-8 text file.

    Lines are stripped of surrounding whitespace, and blank lines (including
    a trailing empty line) are ignored so they are never picked as an empty
    substitution.

    Args:
        file_path: Path to the text file.

    Returns:
        The chosen stripped line, or ``""`` if the file has no non-blank lines.
    """
    # "utf-8-sig" reads plain UTF-8 and also drops a leading BOM, which
    # str.strip() would not remove. An explicit encoding avoids depending on
    # the platform's locale default.
    with open(file_path, "r", encoding="utf-8-sig") as file:
        stripped = (line.strip() for line in file)
        candidates = [line for line in stripped if line]
    if not candidates:
        return ""
    return random.choice(candidates)
|
|
|
|
def add_wildcard(prompt: str, wildcard_files: Dict[str, str]) -> str:
    """Replace wildcard placeholders in ``prompt`` with random lines.

    Placeholders are processed in ``wildcard_files`` iteration order. For
    each key still present in the (progressively updated) prompt, one line is
    drawn from its file and substituted for every occurrence of that key.

    Args:
        prompt: Prompt text that may contain ``__name__`` placeholders.
        wildcard_files: Mapping from placeholder to wildcard file path.

    Returns:
        The prompt with matched placeholders substituted.
    """
    result = prompt
    for placeholder, path in wildcard_files.items():
        if placeholder not in result:
            continue
        replacement = get_random_line_from_file(path)
        result = result.replace(placeholder, replacement)
    return result
|
|
|
|
def preprocess_image_dimensions(width, height):
    """Round ``width`` and ``height`` down to the nearest multiple of 8.

    Values that are already multiples of 8 are returned unchanged.

    Returns:
        ``(width, height)`` adjusted to multiples of 8.
    """
    # x % 8 is 0 for multiples of 8, so subtracting it is a no-op there.
    return width - width % 8, height - height % 8
|
|
|
|
def save_image(image, metadata, output_dir, is_colab):
    """Save ``image`` as a PNG in ``output_dir`` with ``metadata`` embedded.

    ``metadata`` is serialised with ``json.dumps`` and stored in a PNG text
    chunk under the key ``"metadata"``. On Colab the file is named after the
    current local time (``image_YYYYmmdd_HHMMSS.png``); otherwise a random
    UUID4 name is used. ``output_dir`` is created if it does not exist.

    The file is created exclusively, so an existing image is never
    overwritten. If the name is taken (e.g. two Colab images saved within the
    same second), ``_1``, ``_2``, ... is appended before ``.png``.

    Args:
        image: Image object with a PIL-style ``save`` method.
        metadata: JSON-serialisable metadata to embed.
        output_dir: Destination directory.
        is_colab: Selects timestamp-based (True) or UUID-based (False) names.

    Returns:
        Path of the written PNG file.
    """
    if is_colab:
        current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        stem = f"image_{current_time}"
    else:
        stem = str(uuid.uuid4())
    os.makedirs(output_dir, exist_ok=True)

    metadata_str = json.dumps(metadata)
    info = PngImagePlugin.PngInfo()
    info.add_text("metadata", metadata_str)

    suffix = 0
    while True:
        filename = f"{stem}.png" if suffix == 0 else f"{stem}_{suffix}.png"
        filepath = os.path.join(output_dir, filename)
        try:
            # Mode "x" fails if the file exists, so the existence check and
            # creation happen atomically (no silent overwrite, no TOCTOU race).
            fp = open(filepath, "xb")
        except FileExistsError:
            suffix += 1
            continue
        try:
            with fp:
                image.save(fp, "PNG", pnginfo=info)
        except Exception:
            # Don't leave a truncated PNG behind if encoding/writing fails.
            try:
                os.remove(filepath)
            except OSError:
                pass
            raise
        return filepath
| |
| |
def is_google_colab():
    """Return ``True`` if the ``google.colab`` module can be imported.

    Importability of ``google.colab`` is used as the signal that the code is
    running inside Google Colab. Only import failures are treated as "not
    Colab"; unlike a bare ``except:``, this no longer swallows
    KeyboardInterrupt or SystemExit raised during the import.
    """
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True
|
|