"""Subset-construction methods: mosaic and gradient. Both methods take a per-sequence score (typically `contrast_hvlv`) and produce N_SUBSETS lists of pool indices of length TARGET_SIZE. Defaults match the published SF-Cluster Phase XII protocol: N_SUBSETS = 12 TARGET_SIZE = 32 mosaic seeds: s = 0, 1, ..., N_SUBSETS-1 gradient seeds: bin_i * 10 + s for s in {0, 1, 2}, bin_i in {0..3} """ from __future__ import annotations from pathlib import Path from typing import List, Optional, Sequence import numpy as np from .pool import Pool, pool_msa, write_a3m from .score import contrast_hvlv N_SUBSETS = 12 TARGET_SIZE = 32 def _subsample(indices: Sequence[int], size: int, rng: np.random.Generator) -> List[int]: """Sample `size` items from `indices` without replacement if possible, with replacement otherwise. Empty input returns [].""" idx = list(indices) if len(idx) == 0: return [] if len(idx) >= size: return list(rng.choice(idx, size=size, replace=False)) return list(rng.choice(idx, size=size, replace=True)) # --------------------------------------------------------------------------- # Method: mosaic # --------------------------------------------------------------------------- def method_mosaic(score: np.ndarray, n_subsets: int = N_SUBSETS, subset_size: int = TARGET_SIZE, *, high_n: int = 11, low_n: int = 11, mid_n: int = 10) -> List[List[int]]: """Tri-stratified mosaic: each subset mixes high/low/mid score tiers. Pool is tri-stratified on `score` (low / mid / high terciles), and each of `n_subsets` subsets samples (high_n + low_n + mid_n) = subset_size items. Seeds: subset s uses np.random.default_rng(seed=s). Args: score: (N,) per-pool-sequence score (e.g., contrast_hvlv). n_subsets: number of subsets to build (default 12). subset_size: total seqs per subset; must equal high_n+low_n+mid_n. high_n, low_n, mid_n: per-tier sample counts (defaults 11/11/10). Returns: list of n_subsets lists of pool indices, length == subset_size each. """ if high_n + low_n + mid_n != subset_size: raise ValueError( f"high_n+low_n+mid_n ({high_n+low_n+mid_n}) != subset_size ({subset_size})" ) score = np.asarray(score) if score.ndim != 1: raise ValueError("score must be 1-D") N = score.shape[0] if N == 0: raise ValueError("empty score array") sorted_idx = np.argsort(score) low_group = list(sorted_idx[: N // 3]) high_group = list(sorted_idx[2 * N // 3 :]) mid_group = list(sorted_idx[N // 3 : 2 * N // 3]) subsets: List[List[int]] = [] for s in range(n_subsets): rng = np.random.default_rng(seed=s) hi = _subsample(high_group, high_n, rng) lo = _subsample(low_group, low_n, rng) mid = _subsample(mid_group, mid_n, rng) subsets.append([int(x) for x in (hi + lo + mid)]) return subsets # --------------------------------------------------------------------------- # Method: gradient # --------------------------------------------------------------------------- def method_gradient(score: np.ndarray, n_subsets: int = N_SUBSETS, subset_size: int = TARGET_SIZE, *, n_bins: int = 4, subsets_per_bin: int = 3) -> List[List[int]]: """Homogeneous per-quartile subsets along the `score` gradient. Pool is split into `n_bins` equal-size bins on sorted score, then for each bin `subsets_per_bin` subsets are drawn entirely from within that bin. Default 4 bins × 3 subsets-per-bin = 12 subsets. Seeds: bin_i in [0..n_bins-1], s in [0..subsets_per_bin-1] use np.random.default_rng(seed=bin_i*10 + s). Args: score: (N,) per-pool-sequence score. n_subsets: expected total (must == n_bins * subsets_per_bin). subset_size: seqs per subset. n_bins: number of score quantile bins (default 4). subsets_per_bin: subsets drawn per bin (default 3). Returns: list of n_subsets lists of pool indices. """ if n_bins * subsets_per_bin != n_subsets: raise ValueError( f"n_bins*subsets_per_bin ({n_bins*subsets_per_bin}) != n_subsets ({n_subsets})" ) score = np.asarray(score) if score.ndim != 1: raise ValueError("score must be 1-D") N = score.shape[0] if N == 0: raise ValueError("empty score array") sorted_idx = np.argsort(score) # Equal-quantile bins by integer split (matches reference impl for n_bins=4). bins: List[List[int]] = [] for b in range(n_bins): start = (b * N) // n_bins end = ((b + 1) * N) // n_bins bins.append(list(sorted_idx[start:end])) subsets: List[List[int]] = [] for bin_i, bin_idx in enumerate(bins): for s in range(subsets_per_bin): rng = np.random.default_rng(seed=bin_i * 10 + s) chosen = _subsample(bin_idx, subset_size, rng) subsets.append([int(x) for x in chosen]) return subsets # --------------------------------------------------------------------------- # High-level convenience: build_subsets # --------------------------------------------------------------------------- def _write_subset_a3ms(pool: Pool, subsets: List[List[int]], out_dir: Path, method: str, query_index: int = 0) -> List[Path]: """Write one A3M per subset; query (pool[query_index]) is always first.""" out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) q_header = pool.headers[query_index] q_seq = pool.sequences[query_index] paths: List[Path] = [] for s_i, idx_list in enumerate(subsets): seen = {q_header} seqs_for_file = [(q_header, q_seq)] for i in idx_list: h = pool.headers[i] if h in seen: continue seen.add(h) seqs_for_file.append((h, pool.sequences[i])) fname = out_dir / f"{method}_subset_{s_i:03d}.a3m" write_a3m(fname, pool.header_line, seqs_for_file) paths.append(fname) return paths def build_subsets(a3m_path: str | Path, fi_npy_path: str | Path, method: str = "mosaic", *, n_subsets: int = N_SUBSETS, subset_size: int = TARGET_SIZE, hv_percentile: float = 80.0, out_dir: Optional[str | Path] = None, query_index: int = 0): """End-to-end: pool -> score -> subset indices [-> A3M files]. Args: a3m_path: input filtered A3M. fi_npy_path: per-residue FI matrix (N_seq, L) .npy. method: "mosaic" or "gradient". n_subsets: default 12. subset_size: default 32. hv_percentile: HV-column variance percentile for contrast_hvlv. out_dir: if given, write one A3M per subset there. query_index: which pool row is the query seq (placed first). Returns: (pool, score, subsets) or (pool, score, subsets, paths) if out_dir. """ pool = pool_msa(a3m_path, fi_npy_path) score = contrast_hvlv(pool.fi_matrix, percentile=hv_percentile) if method == "mosaic": subsets = method_mosaic(score, n_subsets=n_subsets, subset_size=subset_size) elif method == "gradient": subsets = method_gradient(score, n_subsets=n_subsets, subset_size=subset_size) else: raise ValueError(f"unknown method: {method!r} (expected 'mosaic' or 'gradient')") if out_dir is None: return pool, score, subsets paths = _write_subset_a3ms(pool, subsets, Path(out_dir), method, query_index=query_index) return pool, score, subsets, paths