from typing import Any, Dict, List from src.embeddings.aligned_embeddings import AlignedEmbedder from src.planner.canonical_text import plan_to_text from src.planner.semantic_plan import SemanticPlan class PlanEmbedder: def __init__(self, embedder: AlignedEmbedder | None = None): self.embedder = embedder or AlignedEmbedder(target_dim=512) def embed(self, plan: Any): text = _plan_to_text(plan) return self.embedder.embed_text(text) def _as_list(value: Any) -> List[str]: if value is None: return [] if isinstance(value, list): return [str(v) for v in value if str(v).strip()] if isinstance(value, str): return [value] if value.strip() else [] return [str(value)] def _join(items: List[str]) -> str: cleaned = [i.strip() for i in items if i and i.strip()] return ", ".join(cleaned) def _plan_to_text(plan: Any) -> str: if isinstance(plan, SemanticPlan) or hasattr(plan, "scene"): return plan_to_text(plan) if hasattr(plan, "model_dump"): data = plan.model_dump() elif isinstance(plan, dict): data = plan else: data = {} scene_summary = str(data.get("scene_summary", "")).strip() core = data.get("core_semantics", {}) or {} style = data.get("style_controls", {}) or {} image_c = data.get("image_constraints", {}) or {} audio_c = data.get("audio_constraints", {}) or {} text_c = data.get("text_constraints", {}) or {} parts: List[str] = [] if scene_summary: parts.append(scene_summary) setting = str(core.get("setting", "")).strip() time_of_day = str(core.get("time_of_day", "")).strip() weather = str(core.get("weather", "")).strip() if setting or time_of_day or weather: parts.append( f"Setting: {', '.join([p for p in [setting, time_of_day, weather] if p])}." ) main_subjects = _as_list(core.get("main_subjects")) actions = _as_list(core.get("actions")) if main_subjects: parts.append(f"Subjects: {_join(main_subjects)}.") if actions: parts.append(f"Actions: {_join(actions)}.") visual_style = _as_list(style.get("visual_style")) color_palette = _as_list(style.get("color_palette")) lighting = _as_list(style.get("lighting")) camera = _as_list(style.get("camera")) mood = _as_list(style.get("mood_emotion")) tone = _as_list(style.get("narrative_tone")) if visual_style or color_palette or lighting or camera: parts.append( f"Visual style: {_join(visual_style + color_palette + lighting + camera)}." ) if mood: parts.append(f"Mood: {_join(mood)}.") if tone: parts.append(f"Tone: {_join(tone)}.") objects = _as_list(image_c.get("objects")) environment = _as_list(image_c.get("environment_details")) composition = _as_list(image_c.get("composition")) img_include = _as_list(image_c.get("must_include")) img_avoid = _as_list(image_c.get("must_avoid")) if objects or environment or composition: parts.append( f"Image constraints: {_join(objects + environment + composition)}." ) if img_include: parts.append(f"Image must include: {_join(img_include)}.") if img_avoid: parts.append(f"Image must avoid: {_join(img_avoid)}.") audio_intent = _as_list(audio_c.get("audio_intent")) sound_sources = _as_list(audio_c.get("sound_sources")) ambience = _as_list(audio_c.get("ambience")) tempo = str(audio_c.get("tempo", "")).strip() aud_include = _as_list(audio_c.get("must_include")) aud_avoid = _as_list(audio_c.get("must_avoid")) if audio_intent or sound_sources or ambience: parts.append( f"Audio intent: {_join(audio_intent + sound_sources + ambience)}." ) if tempo: parts.append(f"Audio tempo: {tempo}.") if aud_include: parts.append(f"Audio must include: {_join(aud_include)}.") if aud_avoid: parts.append(f"Audio must avoid: {_join(aud_avoid)}.") keywords = _as_list(text_c.get("keywords")) text_include = _as_list(text_c.get("must_include")) text_avoid = _as_list(text_c.get("must_avoid")) length = str(text_c.get("length", "")).strip() if keywords: parts.append(f"Text keywords: {_join(keywords)}.") if text_include: parts.append(f"Text must include: {_join(text_include)}.") if text_avoid: parts.append(f"Text must avoid: {_join(text_avoid)}.") if length: parts.append(f"Text length: {length}.") return " ".join([p.strip() for p in parts if p.strip()])