"""A3M parsing and pool construction.

The pool ties together aligned sequences from a ColabFold-style A3M and a
per-residue Frustration Index (FI) matrix produced by FrustrAI-Seq.

A3M conventions (ColabFold):
    Line 1: optional header line beginning with '#', e.g. "#91\\t1"
    Then alternating ">header" and sequence lines.
    Sequence lines may contain UPPERCASE match-state letters, '-' gaps, and
    lowercase letters denoting insertion states (not part of the alignment).
"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np


@dataclass
class Pool:
    """Container for sequences + per-residue FI vectors.

    Attributes:
        headers: short headers (first whitespace-separated token of each
            A3M header; "" for an empty header).
        sequences: aligned sequences (lowercase insertion states preserved).
        fi_matrix: (N, L) per-residue FI; columns correspond to match-state
            (uppercase) positions in the aligned sequences.
        header_line: original '#' header line, if present.
        full_headers: full header text (without the leading '>'), parallel
            to ``headers``.
    """

    headers: List[str]
    sequences: List[str]
    fi_matrix: np.ndarray
    header_line: Optional[str] = None
    full_headers: List[str] = field(default_factory=list)

    def __len__(self) -> int:
        return len(self.headers)

    @property
    def n_seq(self) -> int:
        """Number of sequences in the pool."""
        return len(self.headers)

    @property
    def n_cols(self) -> int:
        """Number of FI columns (0 when the FI matrix has no elements)."""
        return int(self.fi_matrix.shape[1]) if self.fi_matrix.size else 0


# ---------------------------------------------------------------------------
# A3M I/O
# ---------------------------------------------------------------------------

def _short_header(header: str) -> str:
    """Return the first whitespace-separated token of *header* ("" if blank)."""
    tokens = header.split(maxsplit=1)
    return tokens[0] if tokens else ""


def read_a3m(path: str | Path) -> Tuple[Optional[str], List[Tuple[str, str]]]:
    """Read an A3M file.

    Returns:
        (header_line, [(header, seq), ...])
        header_line is the leading '#...' line if present, else None.
        header is the full header text without the leading '>'.
        seq is the sequence (lowercase insertion states retained). A sequence
        wrapped over several lines is joined into one string; a header with
        no sequence lines yields "". Blank lines, '#' lines after the first
        line, and any text before the first '>' header are ignored.
    """
    path = Path(path)
    with open(path, encoding="utf-8") as f:
        # Text mode uses universal newlines, so "\r\n" arrives as "\n".
        lines = [ln.rstrip("\n") for ln in f]

    header_line: Optional[str] = None
    body = lines
    if lines and lines[0].startswith("#"):
        header_line = lines[0]
        body = lines[1:]

    seqs: List[Tuple[str, str]] = []
    current: Optional[str] = None
    chunks: List[str] = []
    for ln in body:
        if ln.startswith(">"):
            if current is not None:
                seqs.append((current, "".join(chunks)))
            current = ln[1:]
            chunks = []
        elif current is not None and not ln.startswith("#"):
            # Accumulate every line up to the next '>' so wrapped sequences
            # are joined and a following header is never taken as sequence.
            chunks.append(ln.strip())
    if current is not None:
        seqs.append((current, "".join(chunks)))
    return header_line, seqs


def write_a3m(path: str | Path, header_line: Optional[str],
              seqs: List[Tuple[str, str]]) -> None:
    """Write an A3M file. seqs = [(header, seq), ...].

    Parent directories are created as needed. Each record is written as one
    '>header' line followed by one unwrapped sequence line.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        if header_line is not None:
            f.write(header_line + "\n")
        f.writelines(f">{h}\n{s}\n" for h, s in seqs)


# ---------------------------------------------------------------------------
# Pool construction
# ---------------------------------------------------------------------------

def _dedup_a3m(seqs: List[Tuple[str, str]]) -> Tuple[List[int], List[Tuple[str, str]]]:
    """Deduplicate by short header (first whitespace token).

    The first occurrence of each short header is kept; an empty header maps
    to the short header "" and is deduplicated like any other.

    Returns (kept_indices_into_input, [(short_header, seq), ...]).
    """
    seen = set()
    keep_idx: List[int] = []
    out: List[Tuple[str, str]] = []
    for i, (h, s) in enumerate(seqs):
        short = _short_header(h)
        if short in seen:
            continue
        seen.add(short)
        keep_idx.append(i)
        out.append((short, s))
    return keep_idx, out


def pool_msa(a3m_path: str | Path, fi_npy_path: str | Path, *,
             dedup: bool = True) -> Pool:
    """Build a Pool from an A3M file and a per-residue FI matrix.

    Args:
        a3m_path: path to filtered.a3m (ColabFold style).
        fi_npy_path: path to FI matrix .npy of shape (N_seq, L) where N_seq
            matches the number of sequences in the A3M and L is the number of
            match-state alignment columns. Typically produced by FrustrAI-Seq
            (https://github.com/leuschj/FrustrAI-Seq, HF model:
            leuschj/FrustrAI-Seq).
        dedup: drop duplicates by short header (default True); FI rows are
            filtered with the same kept indices so rows stay aligned.

    Returns:
        Pool object whose fi_matrix is float64.

    Raises:
        ValueError: if the FI matrix is not 2-D, or if its row count
            disagrees with the number of sequences in the A3M.
    """
    header_line, raw_seqs = read_a3m(a3m_path)
    fi = np.load(str(fi_npy_path))
    if fi.ndim != 2:
        raise ValueError(
            f"FI matrix must be 2-D (N_seq, L); got shape {fi.shape}"
        )
    if fi.shape[0] != len(raw_seqs):
        raise ValueError(
            f"FI rows ({fi.shape[0]}) != A3M sequences ({len(raw_seqs)}) "
            f"for {a3m_path}"
        )

    if dedup:
        keep_idx, kept = _dedup_a3m(raw_seqs)
        fi = fi[keep_idx]
        full_headers = [raw_seqs[i][0] for i in keep_idx]
        short_headers = [h for h, _ in kept]
        seqs = [s for _, s in kept]
    else:
        full_headers = [h for h, _ in raw_seqs]
        short_headers = [_short_header(h) for h in full_headers]
        seqs = [s for _, s in raw_seqs]

    return Pool(
        headers=short_headers,
        sequences=seqs,
        fi_matrix=np.asarray(fi, dtype=np.float64),
        header_line=header_line,
        full_headers=full_headers,
    )