""" Utility funcitons related to creating embeddings """ import numpy as np from pathlib import Path import pickle import shelve import rootutils from dpacman.utils import pylogger root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) logger = pylogger.RankedLogger(__name__, rank_zero_only=True) def pkl_to_shelf(pkl_path: str, shelf_path: str): # WARNING: this will load the original pickle once. with open(pkl_path, "rb") as f: data = pickle.load(f) # {sequence_str: np.ndarray or list} with shelve.open(shelf_path, flag="n", writeback=False) as db: for k, v in data.items(): arr = np.asarray(v) # ensure ndarray db[str(k)] = arr def pad_token_embeddings(list_of_arrays, pad_value=0.0): """ list_of_arrays: list of (L_i, D) numpy arrays Returns: padded: (N, L_max, D) array mask: (N, L_max) boolean array where True = real token, False = padding """ N = len(list_of_arrays) D = list_of_arrays[0].shape[1] L_max = max(arr.shape[0] for arr in list_of_arrays) padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype) mask = np.zeros((N, L_max), dtype=bool) for i, arr in enumerate(list_of_arrays): L = arr.shape[0] padded[i, :L] = arr mask[i, :L] = True return padded, mask def _to_numpy(x): """Best-effort: convert torch.Tensor or arraylikes to np.ndarray (CPU).""" try: import torch if isinstance(x, torch.Tensor): return x.detach().cpu().numpy() except Exception: pass return np.asarray(x) def embed_and_save( seqs, ids, embedder, out_path, batch_size=1, save_as_shelf: bool = True ): """ Using the passed embedder, make embeddings and store as a pickle mapping: {sequence (str): embedding (np.ndarray)} Notes: - If multiple entries share the exact same sequence string, the *last one wins* in the embedder. - Validates that every requested sequence has an embedding. """ out_path = Path(out_path) pkl_path = out_path.with_suffix(".pkl") # 1) Run the embedder (expects dict: {seq: embedding}) embs_dict = embedder.embed(seqs, batch_size=batch_size) if not isinstance(embs_dict, dict): raise TypeError(f"Expected dict from embedder.embed, got {type(embs_dict)}") # 2) Detect duplicates in the input order seen, dupes = set(), 0 for s in seqs: if s in seen: dupes += 1 seen.add(s) if dupes: msg = ( f"[embed_and_save] Warning: {dupes} duplicate sequence(s) in input; " f"pickle will contain one entry per unique sequence." ) try: (logger.info if logger is not None else print)(msg) except Exception: print(msg, flush=True) # 3) Build ordered mapping (respect input order; last occurrence already reflected in embs_dict) mapping = {} missing = [] for s in seqs: if s in mapping: continue # already stored (keep one per unique sequence) e = embs_dict.get(s) if e is None: missing.append(s) continue mapping[s] = _to_numpy(e) if missing: raise KeyError( f"Embedder did not return embeddings for {len(missing)} sequence(s). " f"Example: {missing[0][:50]}..." 
) # 4) Save pickle with open(pkl_path, "wb") as f: pickle.dump(mapping, f, protocol=5) logger.info(f"Saved as pkl at {pkl_path}") # 5) Optional tiny manifest try: n = len(mapping) ndims = [v.ndim for v in mapping.values()] n_vec = sum(d == 1 for d in ndims) # pooled (D,) n_tok = sum(d == 2 for d in ndims) # per-token (L, D) n_other = n - n_vec - n_tok with open(out_path.with_suffix(".pkl.meta"), "w") as mf: mf.write( f"entries={n}\npooled_1d={n_vec}\nper_token_2d={n_tok}\nother={n_other}\n" ) except Exception: pass if save_as_shelf: shelf_path = str(pkl_path).replace(".pkl", ".shelf") pkl_to_shelf(pkl_path=pkl_path, shelf_path=shelf_path) logger.info(f"Saved as shelf at {shelf_path}") return pkl_path
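

# ---------------------------------------------------------------------------
# Illustrative usage sketch (assumptions, not part of the module's API):
# `ToyEmbedder` is a hypothetical stand-in for a real embedder exposing
# `.embed(seqs, batch_size=...) -> {seq: np.ndarray}`. It only demonstrates the
# expected call pattern for `embed_and_save` and `pad_token_embeddings`.
if __name__ == "__main__":
    import tempfile

    class ToyEmbedder:
        """Toy embedder returning random per-token embeddings of shape (len(seq), 8)."""

        def embed(self, seqs, batch_size=1):
            rng = np.random.default_rng(0)
            return {s: rng.normal(size=(len(s), 8)).astype(np.float32) for s in seqs}

    seqs = ["ACDE", "ACDEFG"]
    ids = ["seq_0", "seq_1"]

    with tempfile.TemporaryDirectory() as tmp:
        pkl_file = embed_and_save(
            seqs, ids, ToyEmbedder(), out_path=Path(tmp) / "demo", save_as_shelf=False
        )
        with open(pkl_file, "rb") as f:
            emb_map = pickle.load(f)

    # Pad the ragged per-token embeddings into one batch: (2, 6, 8) plus a (2, 6) mask.
    padded, mask = pad_token_embeddings([emb_map[s] for s in seqs])
    print(padded.shape, mask.shape)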