File size: 7,255 Bytes
ed37502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""Prompt template system using YAML definitions and Jinja2 rendering.

Templates define structured prompts with variable slots for character traits,
poses, outfits, emotions, camera angles, lighting, and scenes. The engine
renders these templates with provided variables to produce final prompts
for ComfyUI workflows.
"""

from __future__ import annotations

import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml
from jinja2 import Environment, BaseLoader

logger = logging.getLogger(__name__)

# Detect a Hugging Face Spaces deployment: HF_SPACES=1, or SPACE_ID present at all.
IS_HF_SPACES = os.environ.get("HF_SPACES") == "1" or os.environ.get("SPACE_ID") is not None

# Built-in location of the prompt template YAML files: the container path on
# Spaces, otherwise a developer-machine checkout path.
_DEFAULT_PROMPTS_DIR = (
    Path("/app/config/templates/prompts")
    if IS_HF_SPACES
    else Path("D:/AI automation/content_engine/config/templates/prompts")
)

# A non-empty CONTENT_ENGINE_PROMPTS_DIR environment variable overrides the
# built-in location, so the module works on machines other than the one whose
# absolute path is hard-coded above.
_PROMPTS_DIR_OVERRIDE = os.environ.get("CONTENT_ENGINE_PROMPTS_DIR")
PROMPTS_DIR = Path(_PROMPTS_DIR_OVERRIDE) if _PROMPTS_DIR_OVERRIDE else _DEFAULT_PROMPTS_DIR


@dataclass
class VariableDefinition:
    """One variable slot declared by a template, plus its allowed values.

    Instances are built by ``TemplateEngine`` from the entries of a template
    file's ``variables`` mapping.
    """

    # Variable name, as referenced from the template's Jinja2 text.
    name: str
    # Kind of value: "choice", "string" or "number".
    # NOTE(review): this dataclass defaults to "choice", while the YAML parser
    # in TemplateEngine falls back to "string" when ``type`` is omitted.
    type: str = field(default="choice")
    # Permitted values (presumably only meaningful for "choice" variables).
    options: list[str] = field(default_factory=list)
    # Value used by rendering when the caller does not supply this variable.
    default: str = field(default="")
    # When True, rendering raises if the variable is missing and has no
    # default (TemplateEngine.render exempts the character_* variables).
    required: bool = field(default=False)
    # Free-form, human-readable explanation of the variable.
    description: str = field(default="")


@dataclass
class PromptTemplate:
    """In-memory form of one prompt template file.

    Holds the raw (unrendered) Jinja2 prompt text, LoRA specs, optional
    sampler overrides and the declared variable slots.
    """

    # Identity / classification
    id: str
    name: str
    category: str = field(default="")
    rating: str = field(default="sfw")  # "sfw" or "nsfw"
    base_model: str = field(default="realistic_vision")

    # LoRA specs; their values may contain Jinja2 variable references.
    loras: list[dict[str, Any]] = field(default_factory=list)

    # Jinja2 source for the positive and negative prompts.
    positive_prompt: str = field(default="")
    negative_prompt: str = field(default="")

    # Sampler settings; None means the template does not specify one.
    steps: int | None = field(default=None)
    cfg: float | None = field(default=None)
    sampler_name: str | None = field(default=None)
    scheduler: str | None = field(default=None)
    width: int | None = field(default=None)
    height: int | None = field(default=None)

    # Declared variable slots, keyed by variable name.
    variables: dict[str, VariableDefinition] = field(default_factory=dict)

    # Motion settings, reserved for future video support.
    motion: dict[str, Any] = field(default_factory=dict)


class TemplateEngine:
    """Loads, manages, and renders prompt templates.

    Templates are parsed from ``*.yaml`` files in ``templates_dir`` and kept in
    memory keyed by template id. Prompt text and LoRA names are rendered as
    Jinja2 templates against caller-supplied variables.
    """

    # Required variables that resolve to "" instead of raising when the caller
    # omits them: they are filled from a selected character, and rendering
    # with no character selected is allowed.
    _CHARACTER_VARS = frozenset({"character_trigger", "character_lora"})

    def __init__(self, templates_dir: Path | str | None = None):
        """Create an engine; call :meth:`load_all` to populate it.

        Args:
            templates_dir: Directory containing the template YAML files. A
                ``str`` is converted to ``Path``. When omitted (or empty),
                ``PROMPTS_DIR`` is used.
        """
        self.templates_dir = Path(templates_dir) if templates_dir else PROMPTS_DIR
        self._templates: dict[str, PromptTemplate] = {}
        # BaseLoader: templates are only ever compiled from in-memory strings.
        self._jinja_env = Environment(loader=BaseLoader())

    def load_all(self) -> None:
        """Load every ``*.yaml`` template found in ``templates_dir``.

        A file that fails to parse is logged and skipped, so one malformed
        template does not block the rest. Loaded templates are added to any the
        engine already holds, replacing those with the same id.
        """
        if not self.templates_dir.exists():
            logger.warning("Templates directory does not exist: %s", self.templates_dir)
            return

        # id -> file it was loaded from during this scan, to report duplicates.
        loaded_from: dict[str, Path] = {}
        # Sorted so that when two files declare the same id, which one wins is
        # deterministic rather than dependent on filesystem listing order.
        for path in sorted(self.templates_dir.glob("*.yaml")):
            try:
                template = self._parse_template(path)
            except Exception:
                logger.exception("Failed to load template %s", path)
                continue
            if template.id in loaded_from:
                logger.warning(
                    "Template id %r in %s overrides the one loaded from %s",
                    template.id,
                    path,
                    loaded_from[template.id],
                )
            loaded_from[template.id] = path
            self._templates[template.id] = template
            logger.info("Loaded template: %s", template.id)

    def _parse_template(self, path: Path) -> PromptTemplate:
        """Parse one YAML file into a :class:`PromptTemplate`.

        ``id`` and ``name`` default to the file stem. Sampler settings are read
        from the optional ``sampler`` mapping. Keys present with an empty (null)
        value are treated like missing keys.

        Raises:
            ValueError: If the file's top level is not a YAML mapping
                (this includes an empty file).
            yaml.YAMLError: If the file is not valid YAML.
            OSError: If the file cannot be read.
        """
        # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) would
        # mis-decode or reject non-ASCII prompt text.
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        if not isinstance(data, dict):
            raise ValueError(
                f"Template {path} must contain a YAML mapping, got {type(data).__name__}"
            )

        variables: dict[str, VariableDefinition] = {}
        # ``or {}`` / ``or []`` guard keys written with no value, which YAML loads as None.
        for var_name, var_def in (data.get("variables") or {}).items():
            var_def = var_def or {}
            variables[var_name] = VariableDefinition(
                name=var_name,
                type=var_def.get("type", "string"),
                options=var_def.get("options") or [],
                default=var_def.get("default", ""),
                required=var_def.get("required", False),
                description=var_def.get("description", ""),
            )

        sampler = data.get("sampler") or {}

        return PromptTemplate(
            # str(): templates are looked up by string id even if YAML parsed a number.
            id=str(data.get("id") or path.stem),
            name=data.get("name") or path.stem,
            category=data.get("category", ""),
            rating=data.get("rating", "sfw"),
            base_model=data.get("base_model", "realistic_vision"),
            loras=data.get("loras") or [],
            positive_prompt=data.get("positive_prompt") or "",
            negative_prompt=data.get("negative_prompt") or "",
            steps=sampler.get("steps"),
            cfg=sampler.get("cfg"),
            sampler_name=sampler.get("sampler_name"),
            scheduler=sampler.get("scheduler"),
            width=sampler.get("width"),
            height=sampler.get("height"),
            variables=variables,
            motion=data.get("motion") or {},
        )

    def get(self, template_id: str) -> PromptTemplate:
        """Return the loaded template with ``template_id``.

        Raises:
            KeyError: If no template with that id has been loaded.
        """
        try:
            return self._templates[template_id]
        except KeyError:
            raise KeyError(f"Template not found: {template_id}") from None

    def list_templates(self) -> list[PromptTemplate]:
        """Return all loaded templates as a new list."""
        return list(self._templates.values())

    def render(
        self,
        template_id: str,
        variables: dict[str, str],
    ) -> RenderedPrompt:
        """Render a template with the given variables.

        Each declared variable is resolved in this order: the caller's value,
        then the declared default (falsy defaults such as ``0`` count; only
        ``""``/``None`` mean "no default"). A missing required variable raises,
        except the character variables, which become ``""``. Caller variables
        the template does not declare are passed through unchanged.

        Args:
            template_id: Id of a loaded template.
            variables: Caller-supplied variable values.

        Returns:
            The rendered positive/negative prompts, the LoRA specs with rendered
            names, the resolved variables, and the source template. A LoRA name
            may render to ``""`` (e.g. ``{{character_lora}}`` with no character
            selected); it is returned as-is.

        Raises:
            KeyError: If ``template_id`` is not loaded.
            ValueError: If a required, non-character variable is missing and has
                no default.
        """
        template = self.get(template_id)

        resolved_vars: dict[str, Any] = {}
        for var_name, var_def in template.variables.items():
            if var_name in variables:
                resolved_vars[var_name] = variables[var_name]
            elif var_def.default is not None and var_def.default != "":
                # Explicit check rather than truthiness so defaults like 0 or
                # False are still applied.
                resolved_vars[var_name] = var_def.default
            elif var_def.required:
                if var_name in self._CHARACTER_VARS:
                    resolved_vars[var_name] = ""
                else:
                    raise ValueError(f"Required variable '{var_name}' not provided")

        # Pass through any extra caller variables not declared by the template.
        for key, value in variables.items():
            resolved_vars.setdefault(key, value)

        positive = self._render_string(template.positive_prompt, resolved_vars)
        negative = self._render_string(template.negative_prompt, resolved_vars)

        # LoRA names may reference variables such as {{character_lora}}.
        rendered_loras = [
            {
                "name": self._render_string(spec.get("name", ""), resolved_vars),
                "strength_model": spec.get("strength_model", 0.85),
                "strength_clip": spec.get("strength_clip", 0.85),
            }
            for spec in template.loras
        ]

        return RenderedPrompt(
            positive_prompt=positive,
            negative_prompt=negative,
            loras=rendered_loras,
            variables=resolved_vars,
            template=template,
        )

    def _render_string(self, template_str: str, variables: dict[str, str]) -> str:
        """Render a Jinja2 template string with ``variables``; empty input gives ``""``."""
        if not template_str:
            return ""
        tmpl = self._jinja_env.from_string(template_str)
        return tmpl.render(**variables)


@dataclass
class RenderedPrompt:
    """Output of ``TemplateEngine.render``: final prompt text plus what produced it."""

    # Positive prompt after Jinja2 rendering.
    positive_prompt: str
    # Negative prompt after Jinja2 rendering.
    negative_prompt: str
    # LoRA specs with rendered names; render() builds each dict with the keys
    # "name", "strength_model" and "strength_clip".
    loras: list[dict[str, Any]]
    # Variable values actually used for rendering.
    variables: dict[str, str]
    # The template the prompts were rendered from.
    template: PromptTemplate