| """ | |
| MoodSyncAI: Multi-Modal Sentiment & Emotion Analyser | |
| ==================================================== | |
| Components: | |
| - Visual emotion: ViT (Vision Transformer) - trpakov/vit-face-expression | |
| - Text emotion: DistilRoBERTa transformer - j-hartmann/emotion-english-distilroberta-base | |
| - Fusion: Valence-aligned multimodal fusion + mismatch detection | |
| - Generative: FLAN-T5 (with safe template fallback) for plain-language summary | |
| - Webcam: Short video upload/recording, per-frame emotion timeline | |
| All models are free/open-source from Hugging Face. Runs on CPU. | |
| """ | |
import os
import warnings
from typing import List, Tuple, Dict

warnings.filterwarnings("ignore")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import numpy as np
import pandas as pd
from PIL import Image
import cv2
import plotly.graph_objects as go
import plotly.express as px
import gradio as gr
import torch
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForImageClassification,
    AutoModelForSequenceClassification,
    AutoImageProcessor,
)
# -------------------------------------------------------------
# Model identifiers (all free / public on Hugging Face Hub)
# -------------------------------------------------------------
VISION_MODEL = "trpakov/vit-face-expression"                  # ViT for facial emotion
TEXT_MODEL = "j-hartmann/emotion-english-distilroberta-base"  # 7 emotions
GEN_MODEL = "google/flan-t5-base"                             # generative summariser
ASR_MODEL = "openai/whisper-tiny"                             # speech-to-text (Whisper)

DEVICE = 0 if torch.cuda.is_available() else -1
print(f"[MoodSyncAI] Torch device: {'cuda' if DEVICE == 0 else 'cpu'}")
# -------------------------------------------------------------
# Lazy-loaded model singletons
# -------------------------------------------------------------
_vision_pipe = None
_text_pipe = None
_gen_tokenizer = None
_gen_model = None
_face_cascade = None
_asr_pipe = None
_vit_attn_model = None
_vit_attn_processor = None
_text_attn_model = None
_text_attn_tokenizer = None
def get_vision_pipe():
    global _vision_pipe
    if _vision_pipe is None:
        print("[MoodSyncAI] Loading vision model:", VISION_MODEL)
        _vision_pipe = pipeline(
            "image-classification",
            model=VISION_MODEL,
            device=DEVICE,
            top_k=None,
        )
    return _vision_pipe
def get_text_pipe():
    global _text_pipe
    if _text_pipe is None:
        print("[MoodSyncAI] Loading text model:", TEXT_MODEL)
        _text_pipe = pipeline(
            "text-classification",
            model=TEXT_MODEL,
            device=DEVICE,
            top_k=None,
            truncation=True,
        )
    return _text_pipe
def get_generator():
    global _gen_tokenizer, _gen_model
    if _gen_model is None:
        try:
            print("[MoodSyncAI] Loading generator:", GEN_MODEL)
            _gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
            _gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL)
            if DEVICE == 0:
                _gen_model = _gen_model.to("cuda")
        except Exception as e:
            print("[MoodSyncAI] Generator load failed, will use template fallback:", e)
            _gen_tokenizer, _gen_model = None, None
    return _gen_tokenizer, _gen_model
def get_face_cascade():
    global _face_cascade
    if _face_cascade is None:
        path = os.path.join(cv2.data.haarcascades, "haarcascade_frontalface_default.xml")
        _face_cascade = cv2.CascadeClassifier(path)
    return _face_cascade
# -------------------------------------------------------------
# Valence map: used to align textual and visual signals
# -------------------------------------------------------------
VALENCE = {
    # text emotions (from distilroberta)
    "joy": 1.0,
    "love": 1.0,
    "surprise": 0.3,
    "neutral": 0.0,
    "sadness": -1.0,
    "fear": -0.8,  # shared by the text and vision label sets
    "anger": -0.9,
    "disgust": -0.8,
    # vision labels (ViT face expression labels)
    "happy": 1.0,
    "happiness": 1.0,
    "sad": -1.0,
    "angry": -0.9,
    "fearful": -0.8,
    "disgusted": -0.8,
    "surprised": 0.3,
    "contempt": -0.6,
}
def valence_of(label: str) -> float:
    return VALENCE.get(label.lower().strip(), 0.0)
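# Example: valence_of("Happy") -> 1.0; unknown labels fall back to 0.0 (neutral).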
# -------------------------------------------------------------
# Face detection (crops to face for better accuracy; falls back to full image)
# -------------------------------------------------------------
def detect_and_crop_face(pil_img: Image.Image) -> Image.Image:
    try:
        cascade = get_face_cascade()
        rgb = np.array(pil_img.convert("RGB"))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        faces = cascade.detectMultiScale(gray, scaleFactor=1.2, minNeighbors=5, minSize=(60, 60))
        if len(faces) == 0:
            return pil_img
        # Pick the largest face
        x, y, w, h = max(faces, key=lambda b: b[2] * b[3])
        pad = int(0.15 * max(w, h))
        x0 = max(0, x - pad)
        y0 = max(0, y - pad)
        x1 = min(rgb.shape[1], x + w + pad)
        y1 = min(rgb.shape[0], y + h + pad)
        return Image.fromarray(rgb[y0:y1, x0:x1])
    except Exception:
        return pil_img
# -------------------------------------------------------------
# Core analysis helpers
# -------------------------------------------------------------
def predict_visual(pil_img: Image.Image) -> List[Dict]:
    pipe = get_vision_pipe()
    face = detect_and_crop_face(pil_img)
    preds = pipe(face)
    # normalise into a list of {label, score}
    return [{"label": p["label"], "score": float(p["score"])} for p in preds]
def predict_text(text: str) -> List[Dict]:
    if not text or not text.strip():
        return [{"label": "neutral", "score": 1.0}]
    pipe = get_text_pipe()
    preds = pipe(text)[0]  # top_k=None -> list of scores for all labels
    return [{"label": p["label"], "score": float(p["score"])} for p in preds]
def top1(preds: List[Dict]) -> Tuple[str, float]:
    p = max(preds, key=lambda d: d["score"])
    return p["label"], p["score"]
def weighted_valence(preds: List[Dict]) -> float:
    return sum(p["score"] * valence_of(p["label"]) for p in preds)
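# Worked example: preds = [{"label": "joy", "score": 0.7}, {"label": "sadness", "score": 0.3}]
# gives 0.7 * (+1.0) + 0.3 * (-1.0) = +0.4, i.e. a mildly positive weighted valence.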
def fuse(visual_preds: List[Dict], text_preds: List[Dict]) -> Dict:
    v_label, v_conf = top1(visual_preds)
    t_label, t_conf = top1(text_preds)
    v_val = weighted_valence(visual_preds)
    t_val = weighted_valence(text_preds)
    delta = v_val - t_val
    # mismatch: opposite sign with meaningful magnitude
    mismatch = (v_val * t_val < -0.05) or (abs(delta) > 0.9)
    if mismatch:
        status = "MISMATCH DETECTED"
        badge = "🟠"
    elif abs(delta) < 0.35:
        status = "ALIGNED"
        badge = "🟢"
    else:
        status = "PARTIALLY ALIGNED"
        badge = "🟡"
    # overall valence (weighted average favouring visual when mismatched)
    if mismatch:
        overall_val = 0.6 * v_val + 0.4 * t_val
    else:
        overall_val = 0.5 * (v_val + t_val)
    return {
        "visual_label": v_label,
        "visual_conf": v_conf,
        "text_label": t_label,
        "text_conf": t_conf,
        "visual_valence": v_val,
        "text_valence": t_val,
        "delta": delta,
        "status": status,
        "badge": badge,
        "overall_valence": overall_val,
    }
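# Example: a clearly happy face (v_val ≈ +0.9) paired with words scored as anger
# (t_val ≈ -0.8) gives v_val * t_val ≈ -0.72 < -0.05, so the fusion reports
# MISMATCH DETECTED; the overall valence is then 0.6*0.9 + 0.4*(-0.8) = +0.22.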
# -------------------------------------------------------------
# Generative summary
# -------------------------------------------------------------
def template_summary(fusion: Dict) -> str:
    v = fusion["visual_label"]
    vc = fusion["visual_conf"]
    t = fusion["text_label"]
    tc = fusion["text_conf"]
    if fusion["status"].startswith("MISMATCH"):
        return (
            f"Despite expressing **{t}** sentiment verbally ({tc*100:.0f}% confidence), "
            f"the speaker's facial cues indicate **{v}** ({vc*100:.0f}% confidence). "
            f"This incongruence between words and expression is worth noting in the "
            f"context of the conversation - the spoken message may not fully reflect "
            f"how the person actually feels."
        )
    if fusion["status"] == "ALIGNED":
        return (
            f"The speaker's words ({t}, {tc*100:.0f}%) and facial expression "
            f"({v}, {vc*100:.0f}%) are consistent. The overall emotional state "
            f"appears genuine and uncomplicated."
        )
    return (
        f"The speaker shows mild divergence between facial expression ({v}, "
        f"{vc*100:.0f}%) and spoken sentiment ({t}, {tc*100:.0f}%). The signals "
        f"are not contradictory but suggest some nuance in the emotional state."
    )
def generative_summary(fusion: Dict, text_input: str) -> str:
    tok, model = get_generator()
    fallback = template_summary(fusion)
    if model is None or tok is None:
        return fallback
    try:
        mismatch = fusion["status"].startswith("MISMATCH")
        instr = (
            "rewrite as one empathetic paragraph (2-3 sentences) that explicitly "
            "highlights the mismatch between facial expression and spoken words"
            if mismatch else
            "rewrite as one empathetic paragraph (2-3 sentences) noting the emotional state"
        )
        prompt = (
            f"You are an empathetic psychologist. Given the analysis below, {instr}. "
            f"Begin with the word 'The'.\n\n"
            f"Analysis:\n"
            f"- Spoken sentence: \"{text_input or '(none provided)'}\"\n"
            f"- Facial emotion detected: {fusion['visual_label']} "
            f"({fusion['visual_conf']*100:.0f}% confidence)\n"
            f"- Sentiment of the words: {fusion['text_label']} "
            f"({fusion['text_conf']*100:.0f}% confidence)\n"
            f"- Alignment: {fusion['status']}\n\n"
            f"Paragraph:"
        )
        inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=512)
        if DEVICE == 0:
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
        out = model.generate(
            **inputs,
            max_new_tokens=140,
            min_new_tokens=30,
            num_beams=4,
            do_sample=False,
            no_repeat_ngram_size=3,
            early_stopping=True,
        )
        text = tok.decode(out[0], skip_special_tokens=True).strip()
        # Reject obvious echoes / too-short / off-topic outputs
        # (explicit parentheses preserve the intended precedence: reject when
        # *neither* emotion label appears in the generated text)
        bad = (
            len(text) < 50
            or text.lower().startswith(("tell ", "write ", "give "))
            or "story" in text.lower()[:40]
            or (fusion["visual_label"].lower() not in text.lower()
                and fusion["text_label"].lower() not in text.lower())
        )
        if bad:
            return fallback
        return text
    except Exception as e:
        print("[MoodSyncAI] Generation error:", e)
        return fallback
# -------------------------------------------------------------
# Plotly charts
# -------------------------------------------------------------
def bar_chart(preds: List[Dict], title: str, color: str) -> go.Figure:
    df = pd.DataFrame(preds).sort_values("score", ascending=True)
    df["pct"] = (df["score"] * 100).round(1)
    fig = go.Figure(go.Bar(
        x=df["pct"], y=df["label"], orientation="h",
        marker=dict(color=color),
        text=df["pct"].astype(str) + "%",
        textposition="outside",
    ))
    fig.update_layout(
        title=title,
        xaxis_title="Confidence (%)",
        yaxis_title=None,
        xaxis=dict(range=[0, 110]),
        height=320, margin=dict(l=10, r=10, t=40, b=10),
        template="plotly_white",
    )
    return fig
def empty_fig(msg="No data") -> go.Figure:
    fig = go.Figure()
    fig.add_annotation(text=msg, xref="paper", yref="paper",
                       x=0.5, y=0.5, showarrow=False, font=dict(size=14))
    fig.update_layout(height=320, template="plotly_white",
                      margin=dict(l=10, r=10, t=20, b=10))
    return fig
# -------------------------------------------------------------
# Tab 1: Image + Text analysis
# -------------------------------------------------------------
def analyse_image_text(image: Image.Image, text: str):
    if image is None:
        return (empty_fig("Please upload an image"),
                empty_fig("Awaiting input"),
                "### ⚠️ Please upload an image of a face.", "")
    visual_preds = predict_visual(image)
    text_preds = predict_text(text or "")
    fusion = fuse(visual_preds, text_preds)
    summary = generative_summary(fusion, text)
    vfig = bar_chart(visual_preds, "👁️ Visual Emotion (ViT)", "#4C78A8")
    tfig = bar_chart(text_preds, "💬 Text Sentiment (Transformer)", "#54A24B")
    fusion_md = f"""
### {fusion['badge']} Fusion Result: **{fusion['status']}**

| Modality | Top Prediction | Confidence | Valence |
|---|---|---|---|
| 👁️ Visual | **{fusion['visual_label']}** | {fusion['visual_conf']*100:.1f}% | {fusion['visual_valence']:+.2f} |
| 💬 Text | **{fusion['text_label']}** | {fusion['text_conf']*100:.1f}% | {fusion['text_valence']:+.2f} |
| 🔗 Overall valence | — | — | **{fusion['overall_valence']:+.2f}** |
"""
    summary_md = f"### 🧠 Generative Summary\n\n> {summary}"
    return vfig, tfig, fusion_md, summary_md
# -------------------------------------------------------------
# Tab 2: Webcam / short video → emotion timeline
# -------------------------------------------------------------
def sample_frames(video_path: str, max_frames: int = 12) -> List[Tuple[float, Image.Image]]:
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return []
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    # If the total frame count is unknown, read sequentially to count.
    if total <= 0:
        total = 0
        while True:
            ok, _ = cap.read()
            if not ok:
                break
            total += 1
        cap.release()
        cap = cv2.VideoCapture(video_path)
    if total <= 0:
        return []
    duration = total / fps if fps > 0 else 1.0
    n = min(max_frames, max(3, int(duration * 2)))  # ~2 fps sampling target
    target_idxs = set(np.linspace(0, total - 1, n).astype(int).tolist())
    out: List[Tuple[float, Image.Image]] = []
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx in target_idxs:
            ts = idx / fps if fps > 0 else float(idx)
            pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            out.append((float(ts), pil))
            if len(out) >= n:
                break
        idx += 1
    cap.release()
    return out
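# Worked example: a 5 s clip at 25 fps has total = 125 frames, so
# n = min(12, max(3, int(5.0 * 2))) = 10, and np.linspace picks 10 evenly
# spaced frame indices across 0..124 to decode and analyse.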
def analyse_video_text(video_path, text: str):
    if not video_path:
        return (empty_fig("Record or upload a short video"),
                empty_fig("Awaiting input"),
                empty_fig("Awaiting input"),
                "### ⚠️ Please provide a webcam video.", "")
    frames = sample_frames(video_path, max_frames=12)
    if not frames:
        return (empty_fig("Could not read video"),
                empty_fig(""), empty_fig(""),
                "### ⚠️ Could not decode the video file.", "")
    timeline = []  # list of dicts: timestamp plus label -> score
    aggregated: Dict[str, float] = {}
    for ts, pil in frames:
        preds = predict_visual(pil)
        row = {"timestamp": ts}
        for p in preds:
            row[p["label"]] = p["score"]
            aggregated[p["label"]] = aggregated.get(p["label"], 0.0) + p["score"]
        timeline.append(row)
    # Average the aggregated visual predictions across frames
    n = len(frames)
    avg_visual = [{"label": k, "score": v / n} for k, v in aggregated.items()]
    text_preds = predict_text(text or "")
    fusion = fuse(avg_visual, text_preds)
    summary = generative_summary(fusion, text)
    # Timeline figure (one line per emotion)
    df = pd.DataFrame(timeline).fillna(0.0)
    label_cols = [c for c in df.columns if c != "timestamp"]
    tl_fig = go.Figure()
    palette = px.colors.qualitative.Set2
    for i, lbl in enumerate(label_cols):
        tl_fig.add_trace(go.Scatter(
            x=df["timestamp"], y=df[lbl] * 100,
            mode="lines+markers", name=lbl,
            line=dict(color=palette[i % len(palette)], width=2),
        ))
    tl_fig.update_layout(
        title="📈 Emotion Timeline (per frame)",
        xaxis_title="Time (s)", yaxis_title="Confidence (%)",
        height=360, template="plotly_white",
        margin=dict(l=10, r=10, t=40, b=10),
        yaxis=dict(range=[0, 100]),
    )
    vfig = bar_chart(avg_visual, "👁️ Average Visual Emotion", "#4C78A8")
    tfig = bar_chart(text_preds, "💬 Text Sentiment", "#54A24B")
    fusion_md = f"""
### {fusion['badge']} Fusion Result: **{fusion['status']}**

| Modality | Top Prediction | Confidence | Valence |
|---|---|---|---|
| 👁️ Visual (avg) | **{fusion['visual_label']}** | {fusion['visual_conf']*100:.1f}% | {fusion['visual_valence']:+.2f} |
| 💬 Text | **{fusion['text_label']}** | {fusion['text_conf']*100:.1f}% | {fusion['text_valence']:+.2f} |
| 🔗 Overall valence | — | — | **{fusion['overall_valence']:+.2f}** |

*Analysed {n} frames from the video.*
"""
    summary_md = f"### 🧠 Generative Summary\n\n> {summary}"
    return tl_fig, vfig, tfig, fusion_md, summary_md
# =============================================================
# NEW FEATURE BLOCK (additive — does not touch Tab 1 / Tab 2)
# =============================================================
# 1) Whisper ASR (audio → text channel)
# 2) Video with audio (transcribe + frame timeline + fusion)
# 3) Attention visualisation (ViT rollout heatmap + text token attention)
# =============================================================
import tempfile
import subprocess
import html as _html
def get_asr_pipe():
    global _asr_pipe
    if _asr_pipe is None:
        print("[MoodSyncAI] Loading ASR model:", ASR_MODEL)
        _asr_pipe = pipeline(
            "automatic-speech-recognition",
            model=ASR_MODEL,
            device=DEVICE,
            chunk_length_s=30,
            return_timestamps=False,
        )
    return _asr_pipe
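# Note: chunk_length_s=30 matches Whisper's native 30 s window; the transformers
# ASR pipeline splits longer audio into overlapping chunks and stitches the decoded
# text back together, so clips longer than 30 s are still transcribed in full.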
def transcribe_audio(audio_path: str) -> str:
    if not audio_path:
        return ""
    try:
        # Load audio ourselves (soundfile/librosa) so we don't depend on
        # Whisper's internal ffmpeg-via-PATH lookup.
        import soundfile as sf
        try:
            audio, sr = sf.read(audio_path, dtype="float32", always_2d=False)
        except Exception:
            import librosa
            audio, sr = librosa.load(audio_path, sr=16000, mono=True)
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        if sr != 16000:
            import librosa
            audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
            sr = 16000
        if audio.size == 0:
            return ""
        pipe = get_asr_pipe()
        out = pipe(
            {"array": audio, "sampling_rate": sr},
            generate_kwargs={"language": "en", "task": "transcribe"},
        )
        text = out.get("text", "") if isinstance(out, dict) else str(out)
        return (text or "").strip()
    except Exception as e:
        print("[MoodSyncAI] Transcription error:", e)
        return ""
def _ffmpeg_exe() -> str:
    try:
        import imageio_ffmpeg
        return imageio_ffmpeg.get_ffmpeg_exe()
    except Exception:
        return "ffmpeg"
def extract_audio_from_video(video_path: str) -> str:
    """Extract mono 16 kHz wav from video. Returns wav path or '' on failure."""
    if not video_path:
        return ""
    try:
        out_path = tempfile.NamedTemporaryFile(
            suffix=".wav", delete=False
        ).name
        cmd = [
            _ffmpeg_exe(), "-y", "-i", video_path,
            "-vn", "-ac", "1", "-ar", "16000",
            "-f", "wav", out_path,
        ]
        proc = subprocess.run(cmd, capture_output=True, timeout=120)
        if proc.returncode != 0 or not os.path.exists(out_path) or os.path.getsize(out_path) < 1024:
            return ""
        return out_path
    except Exception as e:
        print("[MoodSyncAI] Audio-extract error:", e)
        return ""
# -------------------------------------------------------------
# Attention visualisation
# -------------------------------------------------------------
def _get_vit_attn():
    global _vit_attn_model, _vit_attn_processor
    if _vit_attn_model is None:
        print("[MoodSyncAI] Loading ViT (eager attn) for attention rollout")
        _vit_attn_processor = AutoImageProcessor.from_pretrained(VISION_MODEL)
        _vit_attn_model = AutoModelForImageClassification.from_pretrained(
            VISION_MODEL, attn_implementation="eager"
        )
        _vit_attn_model.eval()
        if DEVICE == 0:
            _vit_attn_model = _vit_attn_model.to("cuda")
    return _vit_attn_model, _vit_attn_processor

def _get_text_attn():
    global _text_attn_model, _text_attn_tokenizer
    if _text_attn_model is None:
        print("[MoodSyncAI] Loading text classifier (eager attn) for token attention")
        _text_attn_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)
        _text_attn_model = AutoModelForSequenceClassification.from_pretrained(
            TEXT_MODEL, attn_implementation="eager"
        )
        _text_attn_model.eval()
        if DEVICE == 0:
            _text_attn_model = _text_attn_model.to("cuda")
    return _text_attn_model, _text_attn_tokenizer
def vit_attention_heatmap(pil_img: Image.Image) -> Image.Image:
    """Attention-rollout heatmap overlaid on the (face-cropped) image."""
    try:
        face = detect_and_crop_face(pil_img).convert("RGB")
        model, processor = _get_vit_attn()
        inputs = processor(images=face, return_tensors="pt")
        if DEVICE == 0:
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
        with torch.no_grad():
            out = model(**inputs, output_attentions=True)
        attns = out.attentions  # tuple of L tensors, each (1, H, S, S)
        if not attns:
            return face
        # Attention rollout: average heads, add identity, normalise, multiply layers
        result = None
        for a in attns:
            a = a.mean(dim=1).squeeze(0)  # (S, S)
            a = a + torch.eye(a.size(0), device=a.device)
            a = a / a.sum(dim=-1, keepdim=True)
            result = a if result is None else a @ result
        # CLS-token row, drop the CLS index → per-patch importances
        cls_attn = result[0, 1:].detach().cpu().numpy()
        side = int(np.sqrt(cls_attn.shape[0]))
        if side * side != cls_attn.shape[0]:
            return face
        grid = cls_attn.reshape(side, side)
        grid = (grid - grid.min()) / (grid.max() - grid.min() + 1e-8)
        # Resize heatmap to the face image and blend as a JET overlay
        w, h = face.size
        heat = cv2.resize(grid, (w, h), interpolation=cv2.INTER_CUBIC)
        heat_u8 = (heat * 255).astype(np.uint8)
        color = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)
        color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
        base = np.array(face)
        overlay = (0.55 * base + 0.45 * color).clip(0, 255).astype(np.uint8)
        return Image.fromarray(overlay)
    except Exception as e:
        print("[MoodSyncAI] ViT attention error:", e)
        return pil_img
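# Rollout recap (after Abnar & Zuidema, 2020): per layer, attention is head-averaged,
# an identity matrix is added to model the residual connection, rows are re-normalised,
# and the per-layer matrices are multiplied together; row 0 of the product (the CLS
# token) then scores how strongly each image patch feeds the final classification.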
def text_token_attention_html(text: str) -> str:
    """Render input text with per-token attention intensity (last layer, [CLS] row)."""
    if not text or not text.strip():
        return "<em>(no text)</em>"
    try:
        model, tok = _get_text_attn()
        enc = tok(text, return_tensors="pt", truncation=True, max_length=256)
        if DEVICE == 0:
            enc = {k: v.to("cuda") for k, v in enc.items()}
        with torch.no_grad():
            out = model(**enc, output_attentions=True)
        attns = out.attentions  # tuple of L tensors, each (1, H, S, S)
        if not attns:
            return _html.escape(text)
        last = attns[-1].mean(dim=1).squeeze(0)  # (S, S)
        cls_row = last[0].detach().cpu().numpy()  # importance of each token to CLS
        ids = enc["input_ids"][0].detach().cpu().tolist()
        tokens = tok.convert_ids_to_tokens(ids)
        # Skip special tokens when computing the normalisation range
        specials = set(tok.all_special_tokens)
        keep_mask = np.array([t not in specials for t in tokens])
        if keep_mask.sum() == 0:
            return _html.escape(text)
        scores = cls_row.copy()
        scores_disp = scores[keep_mask]
        lo, hi = scores_disp.min(), scores_disp.max()
        norm = (scores - lo) / (hi - lo + 1e-8)
        norm = np.clip(norm, 0.0, 1.0)
        # Build HTML: merge subword tokens (RoBERTa uses a 'Ġ' prefix for word starts)
        spans = []
        for i, t in enumerate(tokens):
            if t in specials:
                continue
            display = t
            prefix_space = ""
            if display.startswith("Ġ"):
                display = display[1:]
                prefix_space = " "
            elif display.startswith("▁"):
                display = display[1:]
                prefix_space = " "
            intensity = float(norm[i])
            # red highlight, alpha from intensity
            bg = f"rgba(220,38,38,{intensity:.2f})"
            color = "#fff" if intensity > 0.55 else "#111"
            safe = _html.escape(display)
            spans.append(
                f"{prefix_space}<span style=\"background:{bg};color:{color};"
                f"padding:2px 4px;border-radius:4px;margin:1px;"
                f"font-family:monospace\" title=\"{intensity:.2f}\">{safe}</span>"
            )
        body = "".join(spans).strip()
        legend = (
            "<div style='margin-top:8px;font-size:12px;color:#555'>"
            "Darker red = higher attention weight from [CLS] to that token "
            "(last transformer layer, averaged over heads)."
            "</div>"
        )
        return f"<div style='line-height:2;font-size:15px'>{body}</div>{legend}"
    except Exception as e:
        print("[MoodSyncAI] Text attention error:", e)
        return _html.escape(text)
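# Example: RoBERTa's byte-level BPE marks word starts with 'Ġ', so a mid-sentence
# "really well" tokenises to something like ["Ġreally", "Ġwell"]; stripping the
# prefix and re-inserting a space yields readable, highlightable words.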
# -------------------------------------------------------------
# Tab 1 wrapper: existing outputs + (optional) attention viz
# -------------------------------------------------------------
def analyse_image_text_with_attention(image: Image.Image, text: str, show_attn: bool):
    vfig, tfig, fusion_md, summary_md = analyse_image_text(image, text)
    if not show_attn or image is None:
        return (vfig, tfig, fusion_md, summary_md,
                None, "<em>Toggle 'Show attention visualisation' to view.</em>")
    heat = vit_attention_heatmap(image)
    token_html = text_token_attention_html(text or "")
    return vfig, tfig, fusion_md, summary_md, heat, token_html
# -------------------------------------------------------------
# Tab 3: Audio + Image
# -------------------------------------------------------------
def analyse_audio_image(audio_path, image: Image.Image):
    if image is None and not audio_path:
        return ("",
                empty_fig("Provide an image"),
                empty_fig("Provide audio"),
                "### ⚠️ Please provide both an image and audio.", "")
    transcript = transcribe_audio(audio_path) if audio_path else ""
    if not transcript:
        transcript = "(no speech detected)"
    if image is None:
        return (transcript,
                empty_fig("No image provided"),
                empty_fig("(transcript only)"),
                "### ⚠️ Please also provide a face image.", "")
    visual_preds = predict_visual(image)
    spoken = "" if transcript.startswith("(") else transcript
    text_preds = predict_text(spoken)
    fusion = fuse(visual_preds, text_preds)
    summary = generative_summary(fusion, spoken)
    vfig = bar_chart(visual_preds, "👁️ Visual Emotion (ViT)", "#4C78A8")
    tfig = bar_chart(text_preds, "💬 Sentiment of Transcribed Speech", "#54A24B")
    fusion_md = f"""
### {fusion['badge']} Fusion Result: **{fusion['status']}**

| Modality | Top Prediction | Confidence | Valence |
|---|---|---|---|
| 👁️ Visual (image) | **{fusion['visual_label']}** | {fusion['visual_conf']*100:.1f}% | {fusion['visual_valence']:+.2f} |
| 🎙️ Audio → Text | **{fusion['text_label']}** | {fusion['text_conf']*100:.1f}% | {fusion['text_valence']:+.2f} |
| 🔗 Overall valence | — | — | **{fusion['overall_valence']:+.2f}** |
"""
    summary_md = f"### 🧠 Generative Summary\n\n> {summary}"
    return transcript, vfig, tfig, fusion_md, summary_md
# -------------------------------------------------------------
# Tab 4: Video WITH audio (frame timeline + audio transcript → text channel)
# -------------------------------------------------------------
def analyse_video_with_audio(video_path):
    if not video_path:
        return ("",
                empty_fig("Record or upload a video"),
                empty_fig(""), empty_fig(""),
                "### ⚠️ Please provide a video.", "")
    frames = sample_frames(video_path, max_frames=12)
    if not frames:
        return ("",
                empty_fig("Could not read video"),
                empty_fig(""), empty_fig(""),
                "### ⚠️ Could not decode the video file.", "")
    # 1) Audio → transcript
    wav = extract_audio_from_video(video_path)
    transcript = transcribe_audio(wav) if wav else ""
    if wav and os.path.exists(wav):
        try:
            os.remove(wav)
        except Exception:
            pass
    if not transcript:
        transcript = "(no speech detected in the audio track)"
    spoken = "" if transcript.startswith("(") else transcript
    # 2) Per-frame visual + aggregate
    timeline = []
    aggregated: Dict[str, float] = {}
    for ts, pil in frames:
        preds = predict_visual(pil)
        row = {"timestamp": ts}
        for p in preds:
            row[p["label"]] = p["score"]
            aggregated[p["label"]] = aggregated.get(p["label"], 0.0) + p["score"]
        timeline.append(row)
    n = len(frames)
    avg_visual = [{"label": k, "score": v / n} for k, v in aggregated.items()]
    # 3) Text channel from the transcript
    text_preds = predict_text(spoken)
    fusion = fuse(avg_visual, text_preds)
    summary = generative_summary(fusion, spoken)
    # Timeline figure
    df = pd.DataFrame(timeline).fillna(0.0)
    label_cols = [c for c in df.columns if c != "timestamp"]
    tl_fig = go.Figure()
    palette = px.colors.qualitative.Set2
    for i, lbl in enumerate(label_cols):
        tl_fig.add_trace(go.Scatter(
            x=df["timestamp"], y=df[lbl] * 100,
            mode="lines+markers", name=lbl,
            line=dict(color=palette[i % len(palette)], width=2),
        ))
    tl_fig.update_layout(
        title="📈 Emotion Timeline (per frame) — audio transcript drives text channel",
        xaxis_title="Time (s)", yaxis_title="Confidence (%)",
        height=360, template="plotly_white",
        margin=dict(l=10, r=10, t=40, b=10),
        yaxis=dict(range=[0, 100]),
    )
    vfig = bar_chart(avg_visual, "👁️ Avg Visual Emotion (frames)", "#4C78A8")
    tfig = bar_chart(text_preds, "💬 Sentiment of Spoken Audio", "#54A24B")
    fusion_md = f"""
### {fusion['badge']} Fusion Result: **{fusion['status']}**

| Modality | Top Prediction | Confidence | Valence |
|---|---|---|---|
| 👁️ Visual (avg of {n} frames) | **{fusion['visual_label']}** | {fusion['visual_conf']*100:.1f}% | {fusion['visual_valence']:+.2f} |
| 🎙️ Audio transcript | **{fusion['text_label']}** | {fusion['text_conf']*100:.1f}% | {fusion['text_valence']:+.2f} |
| 🔗 Overall valence | — | — | **{fusion['overall_valence']:+.2f}** |

*Spoken words (auto-transcribed):* "{spoken or '—'}"
"""
    summary_md = f"### 🧠 Generative Summary\n\n> {summary}"
    return transcript, tl_fig, vfig, tfig, fusion_md, summary_md
# -------------------------------------------------------------
# Gradio UI
# -------------------------------------------------------------
CSS = """
.gradio-container {max-width: 1200px !important;}
#title {text-align: center;}
footer {display: none !important;}
.show-api, .built-with, .settings {display: none !important;}
"""
with gr.Blocks(title="MoodSyncAI", theme=gr.themes.Soft(), css=CSS) as demo:
    gr.Markdown("# 🎭 MoodSyncAI", elem_id="title")
    gr.Markdown(
        "**Multi-Modal Sentiment & Emotion Analyser** — combines a Vision "
        "Transformer (face), a Transformer text classifier (words), a fusion "
        "layer (mismatch detection), and a generative model (plain-language "
        "summary). 100% open-source."
    )
    with gr.Tabs():
        # ---------------- Tab 1 ----------------
        with gr.Tab("🖼️ Image + Text"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_in = gr.Image(type="pil", label="Face photo", height=320)
                    txt_in = gr.Textbox(
                        label="What the person said",
                        placeholder="e.g., No, I think the project is going really well.",
                        lines=2,
                    )
                    btn1 = gr.Button("🔍 Analyse", variant="primary")
                    attn_toggle1 = gr.Checkbox(
                        label="🔬 Show attention visualisation (ViT rollout + text tokens)",
                        value=False,
                    )
                    gr.Examples(
                        examples=[
                            [None, "No, I think the project is going really well."],
                            [None, "I'm absolutely thrilled about the results!"],
                            [None, "I'm fine, really, don't worry about me."],
                        ],
                        inputs=[img_in, txt_in],
                    )
                with gr.Column(scale=2):
                    fusion_md1 = gr.Markdown()
                    summary_md1 = gr.Markdown()
                    with gr.Row():
                        vbar1 = gr.Plot(label="Visual emotion")
                        tbar1 = gr.Plot(label="Text sentiment")
                    with gr.Accordion("🔬 Attention visualisation", open=False):
                        attn_img1 = gr.Image(
                            label="ViT attention rollout (face)",
                            height=320, interactive=False,
                        )
                        attn_html1 = gr.HTML(label="Text token attention")
            btn1.click(analyse_image_text_with_attention,
                       inputs=[img_in, txt_in, attn_toggle1],
                       outputs=[vbar1, tbar1, fusion_md1, summary_md1,
                                attn_img1, attn_html1])
        # ---------------- Tab 2 ----------------
        with gr.Tab("📹 Webcam / Video + Text"):
            gr.Markdown(
                "Record a short clip from your webcam (3–10 s recommended) **or** "
                "upload a short video. The system samples frames and builds an "
                "emotion timeline."
            )
            with gr.Row():
                with gr.Column(scale=1):
                    vid_in = gr.Video(
                        label="Webcam / video",
                        sources=["webcam", "upload"],
                        height=300,
                    )
                    txt_in2 = gr.Textbox(
                        label="What the person said",
                        placeholder="Type the spoken sentence here…",
                        lines=2,
                    )
                    btn2 = gr.Button("🔍 Analyse video", variant="primary")
                with gr.Column(scale=2):
                    timeline_plot = gr.Plot(label="Emotion timeline")
                    fusion_md2 = gr.Markdown()
                    summary_md2 = gr.Markdown()
                    with gr.Row():
                        vbar2 = gr.Plot(label="Avg visual emotion")
                        tbar2 = gr.Plot(label="Text sentiment")
            btn2.click(analyse_video_text,
                       inputs=[vid_in, txt_in2],
                       outputs=[timeline_plot, vbar2, tbar2, fusion_md2, summary_md2])
        # ---------------- Tab 3: Audio + Image ----------------
        with gr.Tab("🎙️ Audio + Image"):
            gr.Markdown(
                "Speak (or upload audio) **and** provide a face image. Whisper "
                "transcribes the audio; the words become the *text channel* fed "
                "into the multimodal fusion."
            )
            with gr.Row():
                with gr.Column(scale=1):
                    audio_in3 = gr.Audio(
                        label="🎙️ Audio (microphone or upload)",
                        sources=["microphone", "upload"],
                        type="filepath",
                    )
                    img_in3 = gr.Image(type="pil", label="Face photo", height=300)
                    btn3 = gr.Button("🔍 Transcribe & analyse", variant="primary")
                with gr.Column(scale=2):
                    transcript3 = gr.Textbox(
                        label="Auto-transcript (Whisper)",
                        interactive=False, lines=2,
                    )
                    fusion_md3 = gr.Markdown()
                    summary_md3 = gr.Markdown()
                    with gr.Row():
                        vbar3 = gr.Plot(label="Visual emotion")
                        tbar3 = gr.Plot(label="Audio→text sentiment")
            btn3.click(analyse_audio_image,
                       inputs=[audio_in3, img_in3],
                       outputs=[transcript3, vbar3, tbar3, fusion_md3, summary_md3])
        # ---------------- Tab 4: Video WITH audio ----------------
        with gr.Tab("🎬 Video with Audio"):
            gr.Markdown(
                "Record or upload a short video **with sound**. The system extracts "
                "the audio track, transcribes it (Whisper), samples frames for an "
                "emotion timeline, then fuses the visual signal with the spoken-word "
                "sentiment — no manual typing needed."
            )
            with gr.Row():
                with gr.Column(scale=1):
                    vid_in4 = gr.Video(
                        label="Webcam / video (with audio)",
                        sources=["webcam", "upload"],
                        height=300,
                    )
                    btn4 = gr.Button("🔍 Transcribe & analyse video", variant="primary")
                with gr.Column(scale=2):
                    transcript4 = gr.Textbox(
                        label="Auto-transcript (Whisper)",
                        interactive=False, lines=2,
                    )
                    timeline_plot4 = gr.Plot(label="Emotion timeline")
                    fusion_md4 = gr.Markdown()
                    summary_md4 = gr.Markdown()
                    with gr.Row():
                        vbar4 = gr.Plot(label="Avg visual emotion")
                        tbar4 = gr.Plot(label="Audio→text sentiment")
            btn4.click(analyse_video_with_audio,
                       inputs=[vid_in4],
                       outputs=[transcript4, timeline_plot4, vbar4, tbar4,
                                fusion_md4, summary_md4])
        # ---------------- Tab 5: About ----------------
        with gr.Tab("ℹ️ About"):
            gr.Markdown(f"""
### Architecture

| Stage | Model | Type |
|---|---|---|
| Visual emotion | `{VISION_MODEL}` | **Vision Transformer (ViT)** |
| Text sentiment | `{TEXT_MODEL}` | **Transformer (DistilRoBERTa)** |
| Speech-to-text | `{ASR_MODEL}` | **Encoder-Decoder Transformer (Whisper)** |
| Fusion | Valence-aligned multimodal fusion (custom) | rule + weighted |
| Generative summary | `{GEN_MODEL}` | **Encoder-Decoder Transformer (FLAN-T5)** |
| Attention viz | ViT attention rollout + last-layer text attention | interpretability |

### Fusion logic
1. Each modality produces a probability distribution over emotion labels.
2. Labels are mapped to a *valence* score in `[-1, +1]`.
3. We compute a weighted valence per modality, then the delta between them.
4. Opposite signs → **MISMATCH** (amber). Small delta → **ALIGNED** (green).
5. The generative model receives the structured signals and writes a plain-language summary.

### Privacy
All processing runs locally on your machine; no data is sent to external services
after the first model download from the Hugging Face Hub.
""")
if __name__ == "__main__":
    # Warm up the small text model so the first request is snappy
    try:
        get_text_pipe()
    except Exception as e:
        print("[MoodSyncAI] Warmup text failed:", e)
    _on_spaces = bool(os.environ.get("SPACE_ID"))
    demo.queue().launch(
        server_name="0.0.0.0" if _on_spaces else "127.0.0.1",
        server_port=7860,
        inbrowser=not _on_spaces,
        show_error=True,
        show_api=False,
        ssr_mode=False,
    )