"""
Utility functions related to creating embeddings
"""
from pathlib import Path
import pickle
import shelve

import numpy as np
import rootutils

from dpacman.utils import pylogger
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
logger = pylogger.RankedLogger(__name__, rank_zero_only=True)


def pkl_to_shelf(pkl_path: str, shelf_path: str):
    """
    Convert a pickled {sequence: embedding} mapping into a shelve database
    for on-disk, per-key access.

    WARNING: this loads the entire original pickle into memory once.
    """
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)  # {sequence_str: np.ndarray or list}
    # flag="n" always creates a fresh, empty database
    with shelve.open(shelf_path, flag="n", writeback=False) as db:
        for k, v in data.items():
            arr = np.asarray(v)  # ensure ndarray
            db[str(k)] = arr
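
# Usage sketch for pkl_to_shelf (the paths here are hypothetical):
#
#     pkl_to_shelf("embeddings.pkl", "embeddings.shelf")
#     with shelve.open("embeddings.shelf", flag="r") as db:
#         emb = db[some_sequence]  # loads only this entry from disk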


def pad_token_embeddings(list_of_arrays, pad_value=0.0):
    """
    Pad variable-length (L_i, D) numpy arrays to a common length.
    All arrays are assumed to share the feature dimension D and dtype.

    Returns:
        padded: (N, L_max, D) array
        mask: (N, L_max) boolean array where True = real token, False = padding
    """
    if not list_of_arrays:
        raise ValueError("list_of_arrays must be non-empty")
    N = len(list_of_arrays)
    D = list_of_arrays[0].shape[1]
    L_max = max(arr.shape[0] for arr in list_of_arrays)
    padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
    mask = np.zeros((N, L_max), dtype=bool)
    for i, arr in enumerate(list_of_arrays):
        L = arr.shape[0]
        padded[i, :L] = arr
        mask[i, :L] = True
    return padded, mask
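
# Sketch: the mask lets downstream code ignore padding, e.g. masked mean-pooling
# over real tokens only (a minimal example, not part of this module's API):
#
#     padded, mask = pad_token_embeddings(arrs)
#     pooled = (padded * mask[..., None]).sum(1) / mask.sum(1, keepdims=True)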


def _to_numpy(x):
    """Best-effort conversion of torch.Tensor or array-likes to np.ndarray (CPU)."""
    try:
        import torch

        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
    except Exception:
        pass
    return np.asarray(x)
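
# For example (assuming torch is available):
#     _to_numpy(torch.zeros(2, 3))  ->  np.ndarray of shape (2, 3), on CPU
#     _to_numpy([1.0, 2.0])         ->  array([1., 2.])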


def embed_and_save(
    seqs, ids, embedder, out_path, batch_size=1, save_as_shelf: bool = True
):
    """
    Embed `seqs` with the passed embedder and store the result as a pickle
    mapping {sequence (str): embedding (np.ndarray)}.

    Notes:
    - `ids` is currently unused.
    - If multiple entries share the exact same sequence string, only the last
      occurrence's embedding survives in the embedder's output dict.
    - Validates that every requested sequence has an embedding.
    """
    out_path = Path(out_path)
    pkl_path = out_path.with_suffix(".pkl")
    # 1) Run the embedder (expects dict: {seq: embedding})
    embs_dict = embedder.embed(seqs, batch_size=batch_size)
    if not isinstance(embs_dict, dict):
        raise TypeError(f"Expected dict from embedder.embed, got {type(embs_dict)}")
    # 2) Detect duplicates in the input order
    seen, dupes = set(), 0
    for s in seqs:
        if s in seen:
            dupes += 1
        seen.add(s)
    if dupes:
        msg = (
            f"[embed_and_save] Warning: {dupes} duplicate sequence(s) in input; "
            f"pickle will contain one entry per unique sequence."
        )
        try:
            (logger.info if logger is not None else print)(msg)
        except Exception:
            print(msg, flush=True)
    # 3) Build ordered mapping (respect input order; last occurrence already
    #    reflected in embs_dict)
    mapping = {}
    missing = []
    for s in seqs:
        if s in mapping:
            continue  # already stored (keep one per unique sequence)
        e = embs_dict.get(s)
        if e is None:
            missing.append(s)
            continue
        mapping[s] = _to_numpy(e)
    if missing:
        raise KeyError(
            f"Embedder did not return embeddings for {len(missing)} sequence(s). "
            f"Example: {missing[0][:50]}..."
        )
    # 4) Save pickle
    with open(pkl_path, "wb") as f:
        pickle.dump(mapping, f, protocol=5)
    logger.info(f"Saved as pkl at {pkl_path}")
    # 5) Optional tiny manifest (best-effort; never fails the save)
    try:
        n = len(mapping)
        ndims = [v.ndim for v in mapping.values()]
        n_vec = sum(d == 1 for d in ndims)  # pooled (D,)
        n_tok = sum(d == 2 for d in ndims)  # per-token (L, D)
        n_other = n - n_vec - n_tok
        with open(out_path.with_suffix(".pkl.meta"), "w") as mf:
            mf.write(
                f"entries={n}\npooled_1d={n_vec}\nper_token_2d={n_tok}\nother={n_other}\n"
            )
    except Exception:
        pass  # manifest is purely informational; ignore failures
    if save_as_shelf:
        shelf_path = str(pkl_path.with_suffix(".shelf"))
        pkl_to_shelf(pkl_path=str(pkl_path), shelf_path=shelf_path)
        logger.info(f"Saved as shelf at {shelf_path}")
    return pkl_path
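

# Usage sketch for embed_and_save. `MyEmbedder` is hypothetical: any object
# whose .embed(seqs, batch_size=...) returns {seq: np.ndarray} should work.
#
#     embedder = MyEmbedder()
#     seqs = ["MKTAYIAKQR", "MENSDSNDKG"]
#     embed_and_save(seqs, ids=["p1", "p2"], embedder=embedder, out_path="outputs/embs")
#     # writes outputs/embs.pkl, outputs/embs.pkl.meta and, by default,
#     # outputs/embs.shelf (the shelve backend may append its own extension)


if __name__ == "__main__":
    # Minimal smoke test (synthetic data; no embedder or files required).
    arrs = [np.ones((3, 4)), np.ones((5, 4))]
    padded, mask = pad_token_embeddings(arrs)
    assert padded.shape == (2, 5, 4)
    assert mask.sum(axis=1).tolist() == [3, 5]
    print("pad_token_embeddings smoke test passed")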