# src/coherence/coherence_engine.py
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Dict, Optional

from src.coherence.drift_detector import detect_drift
from src.coherence.scorer import CoherenceScorer
from src.embeddings.aligned_embeddings import AlignedEmbedder
from src.embeddings.similarity import cosine_similarity

logger = logging.getLogger(__name__)


class CoherenceEngine:
    """
    Evaluates multimodal coherence, using the correct embedding space for
    each modality pair:

        Text-Image similarity: CLIP shared space
        Text-Audio similarity: CLAP shared space
        Image-Audio similarity: cross-space (CLIP vs CLAP) — not directly comparable

    We omit si_a from scoring by default, since it would compare embeddings
    from different model spaces without a trained bridge.

    If a trained CrossSpaceBridge is loaded via load_bridge(), si_a is computed
    in the learned bridge space and the full MSCI formula activates:

        MSCI = 0.45 * st_i + 0.45 * st_a + 0.10 * si_a

    Missing scores have their weight renormalized away, so without a bridge
    MSCI reduces to (st_i + st_a) / 2.
    """

    def __init__(self, target_dim: int = 512):
        self.embedder = AlignedEmbedder(target_dim=target_dim)
        self.scorer = CoherenceScorer()
        self._bridge = None  # Optional CrossSpaceBridge for si_a

    def load_bridge(self, path: str) -> None:
        """
        Load a trained CrossSpaceBridge to enable image-audio similarity.

        Once loaded, si_a is computed in the bridge's shared space instead
        of being set to None.

        Args:
            path: Path to saved bridge weights (.pt file).
        """
        # Imported lazily so the bridge dependency is only pulled in when used.
        from src.embeddings.cross_space_bridge import CrossSpaceBridge

        bridge_path = Path(path)
        if not bridge_path.exists():
            logger.warning("Bridge file not found: %s — si_a remains disabled", path)
            return
        self._bridge = CrossSpaceBridge.load(bridge_path)
        logger.info("Cross-space bridge loaded — si_a enabled")

    def evaluate(
        self,
        text: str,
        image_path: Optional[str] = None,
        audio_path: Optional[str] = None,
    ) -> Dict[str, Any]:
        # CLIP text embedding (for text-image comparison)
        emb_text_clip = self.embedder.embed_text(text)
        # CLAP text embedding (for text-audio comparison)
        emb_text_clap = self.embedder.embed_text_for_audio(text) if audio_path else None

        emb_image = None
        emb_audio = None
        if image_path:
            emb_image = self.embedder.embed_image(image_path)
        if audio_path:
            emb_audio = self.embedder.embed_audio(audio_path)

        scores: Dict[str, Optional[float]] = {}

        # Text-Image: CLIP shared space (meaningful)
        if emb_text_clip is not None and emb_image is not None:
            scores["st_i"] = float(round(cosine_similarity(emb_text_clip, emb_image), 4))
            logger.debug("st_i = %.4f", scores["st_i"])

        # Text-Audio: CLAP shared space (meaningful)
        if emb_text_clap is not None and emb_audio is not None:
            scores["st_a"] = float(round(cosine_similarity(emb_text_clap, emb_audio), 4))
            logger.debug("st_a = %.4f", scores["st_a"])

        # Image-Audio: cross-space (CLIP image vs CLAP audio).
        # Without a bridge, these live in different spaces — similarity is meaningless.
        # With a trained bridge, project both into a shared space for si_a.
        if self._bridge is not None and emb_image is not None and emb_audio is not None:
            scores["si_a"] = float(round(
                self._bridge.compute_similarity(emb_image, emb_audio), 4
            ))
            logger.debug("si_a = %.4f (via bridge)", scores["si_a"])
        else:
            scores["si_a"] = None

        # Compute MSCI from the available scores, renormalizing the weights
        # so they always sum to 1 over whichever pairs were scored.
        available = {k: v for k, v in scores.items() if v is not None}
        if len(available) >= 2:
            weights = {"st_i": 0.45, "st_a": 0.45, "si_a": 0.10}
            total = sum(weights[k] for k in available if k in weights)
            msci = sum(
                available[k] * weights[k] for k in available if k in weights
            ) / max(total, 1e-6)
            scores["msci"] = float(round(msci, 4))
        elif len(available) == 1:
            scores["msci"] = float(round(next(iter(available.values())), 4))
        else:
            scores["msci"] = None
        logger.info("MSCI = %s (from %d pairwise scores)", scores["msci"], len(available))

        drift = detect_drift(
            scores.get("msci"),
            scores.get("st_i"),
            scores.get("st_a"),
            scores.get("si_a"),
        )
        coherence = self.scorer.score(scores=scores, global_drift=drift["global_drift"])
        return {
            "scores": scores,
            "drift": drift,
            "coherence": coherence,
            "classification": coherence["classification"],
            "final_score": coherence["final_score"],
        }


def evaluate_coherence(
    text: str,
    image_path: Optional[str] = None,
    audio_path: Optional[str] = None,
) -> Dict[str, Any]:
    """Convenience wrapper: build a fresh CoherenceEngine and run one evaluation."""
    engine = CoherenceEngine()
    return engine.evaluate(
        text=text,
        image_path=image_path,
        audio_path=audio_path,
    )
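

if __name__ == "__main__":
    # Minimal smoke-test sketch. The text is arbitrary and the media paths are
    # placeholders; point them at real files before running.
    logging.basicConfig(level=logging.INFO)
    result = evaluate_coherence(
        text="a golden retriever barking in a park",
        image_path="samples/dog.jpg",   # placeholder path
        audio_path="samples/bark.wav",  # placeholder path
    )
    print(result["classification"], result["final_score"])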