# Provenance (scraped page header): uploaded by pratik-250620 with
# "Upload folder using huggingface_hub", commit 6835659 (verified).
from typing import Any, Dict, List
from src.embeddings.aligned_embeddings import AlignedEmbedder
from src.planner.canonical_text import plan_to_text
from src.planner.semantic_plan import SemanticPlan
class PlanEmbedder:
    """Embed a plan by rendering it to text and embedding that text."""

    def __init__(self, embedder: AlignedEmbedder | None = None):
        """Create a plan embedder.

        Args:
            embedder: Text embedder used for the rendered plan text. When
                ``None`` (the default), ``AlignedEmbedder(target_dim=512)`` is
                constructed.
        """
        # Explicit ``is not None`` so a caller-supplied embedder is never
        # replaced by the default merely because it evaluates falsy.
        if embedder is not None:
            self.embedder = embedder
        else:
            self.embedder = AlignedEmbedder(target_dim=512)

    def embed(self, plan: Any):
        """Return the embedding of ``plan``'s text rendering.

        Args:
            plan: Plan to embed; it is converted to a single string by
                ``_plan_to_text`` before embedding.

        Returns:
            Whatever ``self.embedder.embed_text`` returns for that string.
        """
        text = _plan_to_text(plan)
        return self.embedder.embed_text(text)
def _as_list(value: Any) -> List[str]:
if value is None:
return []
if isinstance(value, list):
return [str(v) for v in value if str(v).strip()]
if isinstance(value, str):
return [value] if value.strip() else []
return [str(value)]
def _join(items: List[str]) -> str:
cleaned = [i.strip() for i in items if i and i.strip()]
return ", ".join(cleaned)
# Fields of a flat (dict / model_dump) plan rendered by _plan_to_text, in output
# order. Each entry is (section key, sentence label, field keys); the values of
# all listed fields are merged into one sentence "<label>: v1, v2, ...".
_PLAN_TEXT_FIELDS = (
    ("core_semantics", "Setting", ("setting", "time_of_day", "weather")),
    ("core_semantics", "Subjects", ("main_subjects",)),
    ("core_semantics", "Actions", ("actions",)),
    (
        "style_controls",
        "Visual style",
        ("visual_style", "color_palette", "lighting", "camera"),
    ),
    ("style_controls", "Mood", ("mood_emotion",)),
    ("style_controls", "Tone", ("narrative_tone",)),
    (
        "image_constraints",
        "Image constraints",
        ("objects", "environment_details", "composition"),
    ),
    ("image_constraints", "Image must include", ("must_include",)),
    ("image_constraints", "Image must avoid", ("must_avoid",)),
    (
        "audio_constraints",
        "Audio intent",
        ("audio_intent", "sound_sources", "ambience"),
    ),
    ("audio_constraints", "Audio tempo", ("tempo",)),
    ("audio_constraints", "Audio must include", ("must_include",)),
    ("audio_constraints", "Audio must avoid", ("must_avoid",)),
    ("text_constraints", "Text keywords", ("keywords",)),
    ("text_constraints", "Text must include", ("must_include",)),
    ("text_constraints", "Text must avoid", ("must_avoid",)),
    ("text_constraints", "Text length", ("length",)),
)


def _as_mapping(value: Any) -> Dict[str, Any]:
    """Coerce a plan or plan section to a dict.

    Objects exposing ``model_dump`` are dumped, dicts are returned as-is, and
    anything else (including ``None``) yields an empty dict.
    """
    if hasattr(value, "model_dump"):
        value = value.model_dump()
    return value if isinstance(value, dict) else {}


def _plan_to_text(plan: Any) -> str:
    """Render a plan as a single text string for embedding.

    ``SemanticPlan`` instances and objects with a ``scene`` attribute are
    delegated to ``plan_to_text``. Otherwise the plan is read as a flat mapping
    (via ``model_dump`` or as a dict; anything else is treated as empty). It is
    rendered as the ``scene_summary`` followed by one labeled sentence per
    entry of ``_PLAN_TEXT_FIELDS``, in that order. Missing, ``None`` and blank
    fields are omitted entirely.

    Args:
        plan: The plan to render.

    Returns:
        The space-separated sentences; ``""`` when nothing is present.
    """
    if isinstance(plan, SemanticPlan) or hasattr(plan, "scene"):
        return plan_to_text(plan)

    data = _as_mapping(plan)
    parts: List[str] = []

    # Routing scalars through _as_list/_join drops None values instead of
    # rendering them as the literal text "None" (the old str(...) behavior).
    summary = _join(_as_list(data.get("scene_summary")))
    if summary:
        parts.append(summary)

    for section_key, label, field_keys in _PLAN_TEXT_FIELDS:
        # Sections may be dicts, nested models, None, or absent.
        section = _as_mapping(data.get(section_key))
        values: List[str] = []
        for field_key in field_keys:
            values.extend(_as_list(section.get(field_key)))
        text = _join(values)
        if text:
            parts.append(f"{label}: {text}.")

    return " ".join(p.strip() for p in parts if p.strip())