| """ |
| Utility funcitons related to creating embeddings |
| """ |
|
|
| import numpy as np |
| from pathlib import Path |
| import pickle |
| import shelve |
| import rootutils |
| from dpacman.utils import pylogger |
|
|
| root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) |
| logger = pylogger.RankedLogger(__name__, rank_zero_only=True) |
|
|
|
|
def pkl_to_shelf(pkl_path: str, shelf_path: str):
    """Copy a pickled {key: array} mapping into a shelve database on disk."""
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)
    # flag="n" always creates a fresh shelf, overwriting any existing one.
    with shelve.open(shelf_path, flag="n", writeback=False) as db:
        for k, v in data.items():
            db[str(k)] = np.asarray(v)
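

# Usage sketch for pkl_to_shelf (hedged): round-trip a tiny mapping through a
# shelf. The file names below are illustrative, not part of the project layout.
#
#     with open("embs.pkl", "wb") as f:
#         pickle.dump({"ACDE": np.zeros(8)}, f)
#     pkl_to_shelf(pkl_path="embs.pkl", shelf_path="embs.shelf")
#     with shelve.open("embs.shelf") as db:
#         vec = db["ACDE"]  # np.ndarray of shape (8,)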


def pad_token_embeddings(list_of_arrays, pad_value=0.0):
    """
    Pad a list of per-token embedding arrays to a common length.

    Args:
        list_of_arrays: non-empty list of (L_i, D) numpy arrays
        pad_value: fill value for padded positions

    Returns:
        padded: (N, L_max, D) array
        mask: (N, L_max) boolean array where True = real token, False = padding
    """
    if not list_of_arrays:
        raise ValueError("list_of_arrays must be non-empty")
    N = len(list_of_arrays)
    D = list_of_arrays[0].shape[1]
    L_max = max(arr.shape[0] for arr in list_of_arrays)
    padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
    mask = np.zeros((N, L_max), dtype=bool)
    for i, arr in enumerate(list_of_arrays):
        L = arr.shape[0]
        padded[i, :L] = arr
        mask[i, :L] = True
    return padded, mask
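

# Example for pad_token_embeddings (shapes are illustrative):
#
#     a, b = np.ones((3, 4)), np.ones((5, 4))   # 3 and 5 tokens, 4-dim each
#     padded, mask = pad_token_embeddings([a, b])
#     padded.shape  # (2, 5, 4)
#     mask.shape    # (2, 5)
#     mask.sum()    # 8 real tokens (3 + 5)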


def _to_numpy(x):
    """Best-effort: convert a torch.Tensor or array-like to a CPU np.ndarray."""
    try:
        import torch

        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
    except Exception:
        pass
    return np.asarray(x)
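

# Illustration: plain array-likes fall through to np.asarray, e.g.
# _to_numpy([1.0, 2.0]) -> array([1., 2.]); a torch tensor (when torch is
# installed) is detached and moved to CPU first.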


def embed_and_save(
    seqs, ids, embedder, out_path, batch_size=1, save_as_shelf: bool = True
):
    """
    Using the passed embedder, compute embeddings and store them as a pickle
    mapping {sequence (str): embedding (np.ndarray)}.

    Notes:
        - If multiple entries share the exact same sequence string, the *last
          one wins* inside the embedder; the pickle keeps one entry per unique
          sequence.
        - Validates that every requested sequence has an embedding.
        - `ids` is currently unused.
    """
    out_path = Path(out_path)
    pkl_path = out_path.with_suffix(".pkl")

    embs_dict = embedder.embed(seqs, batch_size=batch_size)
    if not isinstance(embs_dict, dict):
        raise TypeError(f"Expected dict from embedder.embed, got {type(embs_dict)}")

    # Warn when duplicate input sequences collapse to a single pickle entry.
    seen, dupes = set(), 0
    for s in seqs:
        if s in seen:
            dupes += 1
        seen.add(s)
    if dupes:
        logger.info(
            f"[embed_and_save] Warning: {dupes} duplicate sequence(s) in input; "
            f"pickle will contain one entry per unique sequence."
        )

    # Build the {sequence: np.ndarray} mapping, tracking any missing sequences.
    mapping = {}
    missing = []
    for s in seqs:
        if s in mapping:
            continue
        e = embs_dict.get(s)
        if e is None:
            missing.append(s)
            continue
        mapping[s] = _to_numpy(e)

    if missing:
        raise KeyError(
            f"Embedder did not return embeddings for {len(missing)} sequence(s). "
            f"Example: {missing[0][:50]}..."
        )

    with open(pkl_path, "wb") as f:
        pickle.dump(mapping, f, protocol=5)
    logger.info(f"Saved as pkl at {pkl_path}")

    # Best-effort sidecar summary: how many entries are pooled (1D) vs
    # per-token (2D) embeddings.
    try:
        n = len(mapping)
        ndims = [v.ndim for v in mapping.values()]
        n_vec = sum(d == 1 for d in ndims)
        n_tok = sum(d == 2 for d in ndims)
        n_other = n - n_vec - n_tok
        with open(out_path.with_suffix(".pkl.meta"), "w") as mf:
            mf.write(
                f"entries={n}\npooled_1d={n_vec}\nper_token_2d={n_tok}\nother={n_other}\n"
            )
    except Exception:
        pass

    if save_as_shelf:
        # with_suffix is safer than str.replace in case ".pkl" appears
        # elsewhere in the path.
        shelf_path = str(pkl_path.with_suffix(".shelf"))
        pkl_to_shelf(pkl_path=pkl_path, shelf_path=shelf_path)
        logger.info(f"Saved as shelf at {shelf_path}")
    return pkl_path
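

if __name__ == "__main__":
    # Smoke-test sketch for embed_and_save. `DummyEmbedder` is a hypothetical
    # stand-in defined only for this demo; a real embedder just needs an
    # .embed(seqs, batch_size=...) method returning {sequence: embedding}.
    import tempfile

    class DummyEmbedder:
        def embed(self, seqs, batch_size=1):
            rng = np.random.default_rng(0)
            # One pooled 8-dim vector per unique sequence.
            return {s: rng.normal(size=8) for s in seqs}

    with tempfile.TemporaryDirectory() as tmp:
        out = embed_and_save(
            seqs=["ACDE", "WXYZ", "ACDE"],  # duplicate triggers the warning
            ids=None,
            embedder=DummyEmbedder(),
            out_path=Path(tmp) / "embeddings",
        )
        with open(out, "rb") as f:
            mapping = pickle.load(f)
        assert set(mapping) == {"ACDE", "WXYZ"}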