File size: 4,101 Bytes
6835659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import annotations

from typing import Callable, Optional

import numpy as np

from src.embeddings.audio_embedder import AudioEmbedder
from src.embeddings.image_embedder import ImageEmbedder
from src.embeddings.projection import ProjectionHead
from src.embeddings.text_embedder import TextEmbedder
from src.utils.cache import EmbeddingCache


class AlignedEmbedder:
    """
    Cross-modal embedding with correct space alignment.

    Two pre-trained shared spaces are used:
    - CLIP: text ↔ image (both from openai/clip-vit-base-patch32, 512-d)
    - CLAP: text ↔ audio (both from laion/clap-htsat-unfused, 512-d)

    For text-image similarity: use embed_text() and embed_image()
    For text-audio similarity: use embed_text_for_audio() and embed_audio()
    For image-audio similarity: these are cross-space (CLIP vs CLAP) —
        no meaningful direct comparison without a trained bridge.

    ProjectionHead is identity when in_dim == out_dim (preserving pre-trained
    alignment). Only applies a linear transformation when dimensions differ.

    Cache entries are namespaced by modality ("text", "text_clap", "image",
    "audio"). When ``target_dim`` differs from the native 512, the namespace
    is additionally suffixed with ``@<target_dim>``. This prevents caches
    shared between differently-configured instances from handing back
    vectors of the wrong width. The default 512-d configuration keeps the
    original, un-suffixed namespaces, so existing cache entries stay valid.
    """

    # Output width of both the CLIP and CLAP encoders wired up below.
    _NATIVE_DIM: int = 512

    def __init__(
        self,
        target_dim: int = 512,
        enable_cache: bool = True,
        cache_dir: str = ".cache/embeddings",
    ):
        """
        Load the CLIP/CLAP encoders, build projection heads, set up caching.

        Args:
            target_dim: Output dimensionality of every ``embed_*`` method.
            enable_cache: When True, results are memoised in an
                ``EmbeddingCache`` rooted at ``cache_dir``.
            cache_dir: Directory handed to ``EmbeddingCache``.
        """
        self.target_dim = target_dim

        self.text = TextEmbedder()       # CLIP text encoder
        self.image = ImageEmbedder()     # CLIP image encoder
        self.audio = AudioEmbedder()     # CLAP audio encoder (also has text)

        # Identity projections when dims match (512 → 512).
        # NOTE(review): CLAP text and audio share ``audio_proj``. CLIP text and
        # image use *separate* ProjectionHead instances. If target_dim != 512
        # and those heads do not share/derive identical weights, text-image
        # alignment may not survive projection. Confirm against ProjectionHead.
        self.text_proj = ProjectionHead(self._NATIVE_DIM, target_dim)
        self.image_proj = ProjectionHead(self._NATIVE_DIM, target_dim)
        self.audio_proj = ProjectionHead(self._NATIVE_DIM, target_dim)

        self.cache: Optional[EmbeddingCache] = None
        if enable_cache:
            self.cache = EmbeddingCache(cache_dir=cache_dir)

    def _cache_namespace(self, modality: str) -> str:
        """Return the cache namespace for ``modality`` at this target dim."""
        # Keep legacy (un-suffixed) names for the default width so existing
        # on-disk caches remain usable; other widths get their own namespace.
        if self.target_dim == self._NATIVE_DIM:
            return modality
        return f"{modality}@{self.target_dim}"

    def _cached_embed(
        self,
        key: str,
        modality: str,
        encode: Callable[[str], np.ndarray],
        proj: "ProjectionHead",
    ) -> np.ndarray:
        """
        Return ``proj.project(encode(key))``, memoised in the cache if enabled.

        Args:
            key: Text or file path to embed; also the cache key.
            modality: Base cache namespace (e.g. "text", "audio").
            encode: Raw encoder call producing the un-projected embedding.
            proj: Projection head applied to the raw embedding.
        """
        cache = self.cache
        namespace = self._cache_namespace(modality)

        # Compare against None explicitly: a cache object that defines
        # __len__ would be falsy while empty and silently never be used.
        if cache is not None:
            cached = cache.get(key, namespace)
            if cached is not None:
                return cached

        projected = proj.project(encode(key))

        if cache is not None:
            cache.set(key, namespace, projected)

        return projected

    def embed_text(self, text: str) -> np.ndarray:
        """CLIP text embedding — use for text-image comparison."""
        return self._cached_embed(text, "text", self.text.embed, self.text_proj)

    def embed_text_for_audio(self, text: str) -> np.ndarray:
        """CLAP text embedding — use for text-audio comparison."""
        # Projected with audio_proj (not text_proj) so it lands in the same
        # space as embed_audio() output.
        return self._cached_embed(
            text, "text_clap", self.audio.embed_text, self.audio_proj
        )

    def embed_image(self, path: str) -> np.ndarray:
        """CLIP image embedding — use for text-image comparison."""
        return self._cached_embed(path, "image", self.image.embed, self.image_proj)

    def embed_audio(self, path: str) -> np.ndarray:
        """CLAP audio embedding — use for text-audio comparison."""
        return self._cached_embed(path, "audio", self.audio.embed, self.audio_proj)

    @staticmethod
    def shared(
        target_dim: int = 512,
        enable_cache: bool = True,
        cache_dir: str = ".cache/embeddings",
    ) -> "AlignedEmbedder":
        """
        Get the thread-safe shared singleton instance.

        Use this in parallel experiments to avoid loading CLIP+CLAP per thread.
        Delegates to ``src.embeddings.shared_embedder.get_shared_embedder``.
        That module is imported lazily here, presumably to avoid an import
        cycle (verify).
        """
        from src.embeddings.shared_embedder import get_shared_embedder
        return get_shared_embedder(target_dim, enable_cache, cache_dir)