"""
Utility functions related to creating embeddings
"""

import numpy as np
from pathlib import Path
import pickle
import shelve
import rootutils
from dpacman.utils import pylogger

root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
logger = pylogger.RankedLogger(__name__, rank_zero_only=True)


def pkl_to_shelf(pkl_path: str, shelf_path: str):
    """
    Convert a pickled {sequence (str): embedding (np.ndarray)} mapping into a
    shelve database so single embeddings can be read lazily from disk.

    WARNING: this loads the original pickle fully into memory once.
    """
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)  # {sequence_str: np.ndarray or list}
    # flag="n" always creates a fresh database; writeback=False keeps memory flat
    with shelve.open(shelf_path, flag="n", writeback=False) as db:
        for k, v in data.items():
            db[str(k)] = np.asarray(v)  # shelve keys must be str; values become ndarrays

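# Usage sketch (hypothetical paths): convert an existing pickle once, then read
# single embeddings back without loading the whole mapping into memory. Note
# the underlying dbm backend may append its own extension to the shelf file.
#
#     pkl_to_shelf("embeddings.pkl", "embeddings.shelf")
#     with shelve.open("embeddings.shelf", flag="r") as db:
#         emb = db["MKTAYIAKQR"]  # np.ndarray for that sequence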

def pad_token_embeddings(list_of_arrays, pad_value=0.0):
    """
    list_of_arrays: list of (L_i, D) numpy arrays
    Returns:
      padded: (N, L_max, D) array
      mask:   (N, L_max) boolean array where True = real token, False = padding
    """
    if not list_of_arrays:
        raise ValueError("list_of_arrays must be non-empty")
    N = len(list_of_arrays)
    D = list_of_arrays[0].shape[1]
    L_max = max(arr.shape[0] for arr in list_of_arrays)
    padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
    mask = np.zeros((N, L_max), dtype=bool)
    for i, arr in enumerate(list_of_arrays):
        L = arr.shape[0]
        padded[i, :L] = arr
        mask[i, :L] = True
    return padded, mask

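# Usage sketch: three per-token matrices with D=4 pad to (N=3, L_max=5, D=4).
#
#     arrs = [np.ones((3, 4)), np.ones((5, 4)), np.ones((2, 4))]
#     padded, mask = pad_token_embeddings(arrs)
#     assert padded.shape == (3, 5, 4) and mask.sum() == 3 + 5 + 2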

def _to_numpy(x):
    """Best-effort: convert torch.Tensor or arraylikes to np.ndarray (CPU)."""
    try:
        import torch

        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
    except Exception:
        pass
    return np.asarray(x)

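# Note: _to_numpy also handles plain lists/scalars, e.g. _to_numpy([1.0, 2.0])
# returns a float64 array of shape (2,); torch tensors are detached and moved
# to CPU first, so it is safe to call inside a no-grad evaluation loop.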

def embed_and_save(
    seqs, ids, embedder, out_path, batch_size=1, save_as_shelf: bool = True
):
    """
    Embed `seqs` with the passed embedder and store the result as a pickle mapping:
        {sequence (str): embedding (np.ndarray)}

    Notes:
      - `ids` is accepted for interface compatibility but is currently unused here.
      - If multiple entries share the exact same sequence string, the last one
        wins in the embedder's output, so the pickle holds one entry per unique
        sequence.
      - Validates that every requested sequence received an embedding.
    """
    out_path = Path(out_path)
    pkl_path = out_path.with_suffix(".pkl")

    # 1) Run the embedder (expects dict: {seq: embedding})
    embs_dict = embedder.embed(seqs, batch_size=batch_size)
    if not isinstance(embs_dict, dict):
        raise TypeError(f"Expected dict from embedder.embed, got {type(embs_dict)}")

    # 2) Detect duplicates in the input order
    seen, dupes = set(), 0
    for s in seqs:
        if s in seen:
            dupes += 1
        seen.add(s)
    if dupes:
        msg = (
            f"[embed_and_save] Warning: {dupes} duplicate sequence(s) in input; "
            f"pickle will contain one entry per unique sequence."
        )
        try:
            (logger.info if logger is not None else print)(msg)
        except Exception:
            print(msg, flush=True)

    # 3) Build ordered mapping (respect input order; last occurrence already reflected in embs_dict)
    mapping = {}
    missing = []
    for s in seqs:
        if s in mapping:
            continue  # already stored (keep one per unique sequence)
        e = embs_dict.get(s)
        if e is None:
            missing.append(s)
            continue
        mapping[s] = _to_numpy(e)

    if missing:
        raise KeyError(
            f"Embedder did not return embeddings for {len(missing)} sequence(s). "
            f"Example: {missing[0][:50]}..."
        )

    # 4) Save pickle
    with open(pkl_path, "wb") as f:
        pickle.dump(mapping, f, protocol=5)
    logger.info(f"Saved as pkl at {pkl_path}")

    # 5) Optional tiny manifest
    try:
        n = len(mapping)
        ndims = [v.ndim for v in mapping.values()]
        n_vec = sum(d == 1 for d in ndims)  # pooled (D,)
        n_tok = sum(d == 2 for d in ndims)  # per-token (L, D)
        n_other = n - n_vec - n_tok
        with open(out_path.with_suffix(".pkl.meta"), "w") as mf:
            mf.write(
                f"entries={n}\npooled_1d={n_vec}\nper_token_2d={n_tok}\nother={n_other}\n"
            )
    except Exception:
        pass

    if save_as_shelf:
        # with_suffix is safer than str.replace, which could match ".pkl"
        # elsewhere in the path
        shelf_path = str(pkl_path.with_suffix(".shelf"))
        pkl_to_shelf(pkl_path=str(pkl_path), shelf_path=shelf_path)
        logger.info(f"Saved as shelf at {shelf_path}")
    return pkl_path
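

# Minimal smoke test (a sketch, not part of the pipeline): runs embed_and_save
# end-to-end with a hypothetical stub embedder that returns one random pooled
# (D,) vector per unique sequence. Any real embedder only needs an
# `embed(seqs, batch_size=...) -> {seq: embedding}` method.
if __name__ == "__main__":
    import tempfile

    class _RandomEmbedder:
        """Hypothetical stand-in for a real sequence embedder."""

        def embed(self, seqs, batch_size=1):
            rng = np.random.default_rng(0)
            # dict comprehension: the last occurrence of a duplicate wins,
            # matching the behavior documented in embed_and_save
            return {s: rng.standard_normal(8) for s in seqs}

    with tempfile.TemporaryDirectory() as tmp:
        out = embed_and_save(
            seqs=["MKT", "AYIA", "MKT"],  # duplicate on purpose
            ids=["a", "b", "c"],
            embedder=_RandomEmbedder(),
            out_path=Path(tmp) / "demo",
            save_as_shelf=False,
        )
        with open(out, "rb") as f:
            demo = pickle.load(f)
        assert set(demo) == {"MKT", "AYIA"} and demo["MKT"].shape == (8,)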