| """Mòdul per a l'agent de "reflexion".
|
|
|
| Entrenament:
|
|
|
| - A partir de parelles (une_ad_auto, une_ad_hitl) per a cada sha1sum, es
|
| comparen les pistes d'audiodescripció (línies amb "(AD)") amb intervals
|
| de temps coincidents.
|
| - Per a cada pista es calcula la durada i les longituds (caràcters i paraules)
|
| i s'etiqueta el cas com S/E/R/X/C:
|
| * S: mateixa longitud aproximada.
|
| * E: allargament de la frase.
|
| * R: reducció de la frase.
|
| * X: eliminació de la frase a la versió HITL.
|
| * C: creació de frase, la versió automàtica era buida/inexistent.
|
| - Es desa un CSV amb les mostres i s'entrena un KNN (K=5) que assigna
|
| probabilitats a cadascun dels casos.
|
|
|
| Aplicació:
|
|
|
| - Per a un SRT donat, es calculen les mateixes variables per a cada pista
|
| d'(AD) i s'aplica el model KNN per decidir S/E/R/X/C.
|
| - S/C: es deixa el text tal qual.
|
| - X: s'elimina la pista.
|
| - E/R: es demana a GPT-4o-mini que allargui/escurci lleugerament la frase,
|
| en una sola crida per a totes les frases afectades.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import csv
|
| import json
|
| import logging
|
| import math
|
| import os
|
| from dataclasses import dataclass
|
| from pathlib import Path
|
| from typing import Dict, Iterable, List, Optional, Tuple
|
|
|
| from langchain_core.messages import HumanMessage, SystemMessage
|
| from langchain_openai import ChatOpenAI
|
|
|
| try:
|
| from sklearn.neighbors import KNeighborsClassifier
|
| import joblib
|
| except Exception:
|
| KNeighborsClassifier = None
|
| joblib = None
|
|
|
| from .introspection import _iter_une_vs_hitl_pairs
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Directory containing this file; every artefact path below is derived from it.
BASE_DIR = Path(__file__).resolve().parent

# Scratch directory for training artefacts. It is created at import time so
# that later writes do not have to check for its existence.
REFINEMENT_TEMP_DIR = BASE_DIR.joinpath("temp")
REFINEMENT_TEMP_DIR.mkdir(parents=True, exist_ok=True)

# Training samples (CSV) and the fitted KNN model (joblib dump).
REFLEXION_CSV_PATH = REFINEMENT_TEMP_DIR.joinpath("reflexion.csv")
REFLEXION_MODEL_PATH = REFINEMENT_TEMP_DIR.joinpath("reflexion_knn.joblib")
|
|
|
|
|
@dataclass
class AdCue:
    """One audio-description cue extracted from an SRT block.

    Fields:
        start: cue start time, in seconds.
        end: cue end time, in seconds.
        text: the audio-description text of the cue.
        block_lines: the lines of the SRT block the cue was read from.
    """

    start: float
    end: float
    text: str
    block_lines: List[str]

    @property
    def duration(self) -> float:
        """Length of the cue in seconds; never negative."""
        span = self.end - self.start
        return span if span > 0.0 else 0.0

    @property
    def char_len(self) -> int:
        """Number of characters in the cue text."""
        return len(self.text)

    @property
    def word_len(self) -> int:
        """Number of whitespace-separated words in the cue text."""
        words = self.text.split()
        return len(words)
|
|
|
|
|
| def _parse_timestamp(ts: str) -> float:
|
| """Converteix un timestamp SRT HH:MM:SS,mmm a segons."""
|
|
|
| try:
|
| hh, mm, rest = ts.split(":")
|
| ss, ms = rest.split(",")
|
| return int(hh) * 3600 + int(mm) * 60 + int(ss) + int(ms) / 1000.0
|
| except Exception:
|
| return 0.0
|
|
|
|
|
def _parse_srt_ad_cues(srt_content: str) -> List[AdCue]:
    """Extract the audio-description ("(AD)") cues from SRT text.

    A block is an index line, a timing line containing "-->" and the
    following non-blank text lines. A block becomes a cue only when at least
    one text line has non-empty text after its first "(AD)" marker; the cue
    text is those fragments joined with single spaces.

    Returns:
        One AdCue per AD block, in file order. ``block_lines`` holds the
        stripped index and timing lines followed by the raw text lines.
    """
    rows = srt_content.splitlines()
    total = len(rows)
    found: List[AdCue] = []
    pos = 0

    while pos < total:
        header = rows[pos].strip()
        pos += 1
        if not header:
            # Blank separator line.
            continue
        if pos >= total:
            break
        if "-->" not in rows[pos]:
            # Not a timing line: leave it in place so the next iteration
            # treats it as a candidate index line (resynchronisation).
            continue

        timing = rows[pos].strip()
        pos += 1
        bounds = timing.split("-->")
        if len(bounds) != 2:
            # More than one arrow on the line: skip this timing line.
            continue
        begin = _parse_timestamp(bounds[0].strip())
        finish = _parse_timestamp(bounds[1].strip())

        # The block body runs up to the next blank line (or end of input).
        body_end = pos
        while body_end < total and rows[body_end].strip():
            body_end += 1
        body = rows[pos:body_end]
        pos = body_end

        # Keep only the text that follows the first "(AD)" marker of a line.
        fragments = [
            piece
            for piece in (ln.partition("(AD)")[2].strip() for ln in body if "(AD)" in ln)
            if piece
        ]
        if not fragments:
            continue

        found.append(
            AdCue(
                start=begin,
                end=finish,
                text=" ".join(fragments).strip(),
                block_lines=[header, timing, *body],
            )
        )

    return found
|
|
|
|
|
| def _intervals_overlap(a_start: float, a_end: float, b_start: float, b_end: float) -> bool:
|
| return max(a_start, b_start) < min(a_end, b_end)
|
|
|
|
|
def _build_training_rows() -> List[Tuple[float, int, int, str]]:
    """Build training rows ``(duration, chars, words, label)``.

    Every (automatic, HITL) SRT pair yielded by ``_iter_une_vs_hitl_pairs``
    is parsed into AD cues. Each automatic cue is compared with the first
    HITL cue whose time interval overlaps it:

    * no overlapping HITL cue -> "X" (removed in the HITL version);
    * HITL text within max(5, 10% of the automatic length) characters -> "S";
    * HITL text longer -> "E"; shorter -> "R".

    HITL cues with no overlapping automatic cue are added as "C" rows with
    zero character and word counts.
    """
    samples: List[Tuple[float, int, int, str]] = []

    for _sha1sum, une_auto, une_hitl in _iter_une_vs_hitl_pairs():
        auto_cues = _parse_srt_ad_cues(une_auto)
        hitl_cues = _parse_srt_ad_cues(une_hitl)

        for auto in auto_cues:
            # First HITL cue (in file order) that overlaps the automatic cue.
            counterpart: Optional[AdCue] = next(
                (
                    candidate
                    for candidate in hitl_cues
                    if _intervals_overlap(auto.start, auto.end, candidate.start, candidate.end)
                ),
                None,
            )
            auto_text = auto.text.strip()

            if counterpart is None:
                if auto_text:
                    samples.append((auto.duration, auto.char_len, auto.word_len, "X"))
                continue

            hitl_text = counterpart.text.strip()
            if not auto_text:
                # Empty automatic text: a creation when HITL has text,
                # otherwise nothing to learn from.
                if hitl_text:
                    samples.append((counterpart.duration, 0, 0, "C"))
                continue

            delta = len(hitl_text) - len(auto_text)
            tolerance = max(5, 0.1 * len(auto_text))
            if abs(delta) <= tolerance:
                verdict = "S"
            elif delta > 0:
                verdict = "E"
            else:
                verdict = "R"
            samples.append((auto.duration, auto.char_len, auto.word_len, verdict))

        # HITL cues that no automatic cue overlaps were created by the reviewer.
        samples.extend(
            (created.duration, 0, 0, "C")
            for created in hitl_cues
            if created.text.strip()
            and not any(
                _intervals_overlap(created.start, created.end, auto.start, auto.end)
                for auto in auto_cues
            )
        )

    return samples
|
|
|
|
|
def train_reflexion_model(max_examples: Optional[int] = None) -> None:
    """Train the reflexion KNN model and persist the samples and the model.

    - Writes ``reflexion.csv`` with columns
      ``duration,char_len,word_len,label``.
    - Fits a distance-weighted KNN with K=5 (or fewer neighbours when fewer
      than five samples exist) and dumps it to ``reflexion_knn.joblib``.

    Nothing is written when sklearn/joblib are unavailable or when no
    training rows remain after applying ``max_examples``.

    Args:
        max_examples: when given, only the first ``max_examples`` rows are
            used for the CSV and the model.
    """
    if KNeighborsClassifier is None or joblib is None:
        logger.warning(
            "sklearn/joblib no disponibles; el mòdul de reflexion no es pot entrenar."
        )
        return

    rows = _build_training_rows()
    if max_examples is not None:
        rows = rows[:max_examples]
    # Check emptiness after truncation: fitting on zero rows would raise and
    # leave a header-only CSV behind.
    if not rows:
        logger.warning("No s'han pogut generar files d'entrenament per a reflexion.")
        return

    with REFLEXION_CSV_PATH.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["duration", "char_len", "word_len", "label"])
        writer.writerows([f"{dur:.3f}", cl, wl, lab] for dur, cl, wl, lab in rows)

    X = [[dur, cl, wl] for dur, cl, wl, _ in rows]
    y = [lab for _, _, _, lab in rows]

    # A KNN fitted with more neighbours than samples fails at prediction
    # time, so cap K at the number of samples.
    n_neighbors = min(5, len(rows))
    knn = KNeighborsClassifier(n_neighbors=n_neighbors, weights="distance")
    knn.fit(X, y)

    joblib.dump(knn, REFLEXION_MODEL_PATH)
    logger.info(
        "Model de reflexion entrenat amb %d mostres i desat a %s",
        len(rows),
        REFLEXION_MODEL_PATH,
    )
|
|
|
|
|
def _load_reflexion_model():
    """Load the persisted KNN model.

    Returns:
        The object stored at ``REFLEXION_MODEL_PATH``, or ``None`` when
        sklearn/joblib are unavailable, the file does not exist, or loading
        fails (a warning is logged in that last case).
    """
    dependencies_ready = KNeighborsClassifier is not None and joblib is not None
    if not dependencies_ready or not REFLEXION_MODEL_PATH.exists():
        return None

    try:
        model = joblib.load(REFLEXION_MODEL_PATH)
    except Exception:
        logger.warning("No s'ha pogut carregar el model de reflexion de %s", REFLEXION_MODEL_PATH)
        model = None
    return model
|
|
|
|
|
def _get_llm() -> Optional[ChatOpenAI]:
    """Create the chat client used for length adjustments.

    Returns:
        A ``ChatOpenAI`` client for ``gpt-4o-mini`` at temperature 0, or
        ``None`` when ``OPENAI_API_KEY`` is unset/empty or the client cannot
        be constructed (both cases are logged).
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        try:
            return ChatOpenAI(model="gpt-4o-mini", temperature=0.0, api_key=api_key)
        except Exception as exc:
            logger.error("No se pudo inicializar ChatOpenAI para reflexion: %s", exc)
            return None

    logger.warning("OPENAI_API_KEY no está configurada; se omite la reflexion.")
    return None
|
|
|
|
|
def _apply_knn_to_cues(cues: List[AdCue]) -> List[str]:
    """Return one S/E/R/X/C label per cue.

    Each cue is described by ``[duration, char_len, word_len]`` and receives
    the class with the highest probability from the persisted KNN model.
    When no model can be loaded, or prediction fails, every cue is labelled
    "S" (keep unchanged).
    """
    model = _load_reflexion_model()
    if model is None:
        return ["S"] * len(cues)

    features = [[cue.duration, cue.char_len, cue.word_len] for cue in cues]
    try:
        probabilities = model.predict_proba(features)
        class_names = list(model.classes_)
        # Pick the most probable class for each row.
        return [str(class_names[int(row.argmax())]) for row in probabilities]
    except Exception as exc:
        logger.error("Error aplicant el model de reflexion: %s", exc)
        return ["S"] * len(cues)
|
|
|
|
|
| def _ask_llm_for_length_adjustments(cues: List[AdCue], labels: List[str]) -> Dict[int, str]:
|
| """Demana al LLM que alargui/curti frases segons E/R.
|
|
|
| Retorna un mapa {index_cue -> nou_text}."""
|
|
|
| llm = _get_llm()
|
| if llm is None:
|
| return {}
|
|
|
| items: List[Dict[str, str]] = []
|
| for idx, (cue, lab) in enumerate(zip(cues, labels)):
|
| if lab not in {"E", "R"}:
|
| continue
|
| items.append({"id": str(idx), "case": lab, "text": cue.text})
|
|
|
| if not items:
|
| return {}
|
|
|
| system = SystemMessage(
|
| content=(
|
| "Ets un assistent que ajusta lleugerament la longitud de frases d'"
|
| "audiodescripció en català. \n"
|
| "Rebràs una llista d'objectes JSON amb camps 'id', 'case' (E o R) i "
|
| "'text'. \n"
|
| "Per a cada element has de tornar un nou text que: \n"
|
| "- Si 'case' és 'E': sigui una mica més llarg (afegint detalls" \
|
| " suaus, sense canviar el sentit).\n"
|
| "- Si 'case' és 'R': sigui una mica més curt, més concís, mantenint el" \
|
| " sentit principal.\n"
|
| "Respon EXCLUSIVAMENT en JSON de la forma:\n"
|
| "{\"segments\":[{\"id\":\"...\",\"new_text\":\"...\"}, ...]}"
|
| )
|
| )
|
|
|
| user = HumanMessage(content=json.dumps({"segments": items}, ensure_ascii=False))
|
|
|
| try:
|
| resp = llm.invoke([system, user])
|
| except Exception as exc:
|
| logger.error("Error llamando al LLM en reflexion (ajustes E/R): %s", exc)
|
| return {}
|
|
|
| text = resp.content if isinstance(resp.content, str) else str(resp.content)
|
| try:
|
| data = json.loads(text)
|
| except json.JSONDecodeError:
|
| logger.warning("Respuesta del LLM en reflexion no es JSON válido: %s", text[:2000])
|
| return {}
|
|
|
| result: Dict[int, str] = {}
|
| for seg in data.get("segments", []):
|
| try:
|
| idx = int(seg.get("id"))
|
| except Exception:
|
| continue
|
| new_text = str(seg.get("new_text", "")).strip()
|
| if new_text:
|
| result[idx] = new_text
|
|
|
| return result
|
|
|
|
|
def _ad_line_text(line: str) -> str:
    """Return the stripped text after the first "(AD)" marker, or "" if none."""
    if "(AD)" not in line:
        return ""
    return line.split("(AD)", 1)[1].strip()


def _replace_ad_text(text_block: List[str], new_text: str) -> List[str]:
    """Put ``new_text`` on the first text-bearing "(AD)" line of a block.

    Further text-bearing "(AD)" lines are dropped because the cue text sent
    for rewriting was the concatenation of all of them; every other line is
    kept unchanged.
    """
    rewritten: List[str] = []
    replaced = False
    for line in text_block:
        if not _ad_line_text(line):
            rewritten.append(line)
        elif not replaced:
            prefix = line.split("(AD)", 1)[0]
            rewritten.append(prefix + "(AD) " + new_text)
            replaced = True
    return rewritten


def refine_srt_with_reflexion(srt_content: str) -> str:
    """Apply the "reflexion" step to an SRT document.

    - Each "(AD)" cue gets a label from ``_apply_knn_to_cues``.
    - Cues labelled "X" are removed together with their whole block and the
      blank separator line that follows it.
    - Cues labelled "E"/"R" get the text returned by
      ``_ask_llm_for_length_adjustments``, when one is returned.
    - Every other line is emitted unchanged.

    When the SRT has no AD cues the input is returned as is. Otherwise the
    result is rebuilt by joining lines with "\\n". Index lines are not
    renumbered after a removal.

    Args:
        srt_content: the full SRT text.

    Returns:
        The refined SRT text.
    """
    cues = _parse_srt_ad_cues(srt_content)
    if not cues:
        return srt_content

    labels = _apply_knn_to_cues(cues)
    adjustments = _ask_llm_for_length_adjustments(cues, labels)

    # Cue indices grouped by interval, in file order, so that several cues
    # sharing the same timing are matched one by one instead of colliding.
    pending: Dict[Tuple[float, float], List[int]] = {}
    for idx, cue in enumerate(cues):
        pending.setdefault((cue.start, cue.end), []).append(idx)

    lines = srt_content.splitlines()
    out_lines: List[str] = []
    i = 0

    # The walk below recognises blocks exactly as _parse_srt_ad_cues does, so
    # both passes agree on which blocks are AD cues.
    while i < len(lines):
        if not lines[i].strip():
            out_lines.append(lines[i])
            i += 1
            continue

        idx_line = lines[i]
        i += 1
        if i >= len(lines):
            out_lines.append(idx_line)
            break

        if "-->" not in lines[i]:
            # Not a timing line: emit the current line and let the next
            # iteration treat lines[i] as a candidate index line.
            out_lines.append(idx_line)
            continue

        time_line = lines[i]
        i += 1
        bounds = time_line.strip().split("-->")
        if len(bounds) != 2:
            # Malformed timing line (more than one arrow): pass it through.
            out_lines.append(idx_line)
            out_lines.append(time_line)
            continue
        start = _parse_timestamp(bounds[0].strip())
        end = _parse_timestamp(bounds[1].strip())

        text_block: List[str] = []
        while i < len(lines) and lines[i].strip():
            text_block.append(lines[i])
            i += 1

        # Only blocks that carry AD text can correspond to a parsed cue; a
        # plain subtitle sharing the timing of an AD cue is left untouched.
        is_ad_block = any(_ad_line_text(tl) for tl in text_block)
        queue = pending.get((start, end)) if is_ad_block else None
        if not queue:
            out_lines.append(idx_line)
            out_lines.append(time_line)
            out_lines.extend(text_block)
            continue

        cue_idx = queue.pop(0)
        label = labels[cue_idx] if cue_idx < len(labels) else "S"

        if label == "X":
            # Drop the block and the blank separator that follows it.
            if i < len(lines) and not lines[i].strip():
                i += 1
            continue

        # Only E/R cues may be rewritten, whatever ids the LLM returned.
        new_text = adjustments.get(cue_idx) if label in {"E", "R"} else None
        if new_text:
            # Collapse whitespace so an embedded newline cannot split the block.
            new_text = " ".join(new_text.split())
        if new_text:
            text_block = _replace_ad_text(text_block, new_text)

        out_lines.append(idx_line)
        out_lines.append(time_line)
        out_lines.extend(text_block)

    return "\n".join(out_lines)
|
|
|