| """A3M parsing and pool construction. |
| |
| The pool ties together aligned sequences from a ColabFold-style A3M and a |
| per-residue Frustration Index (FI) matrix produced by FrustrAI-Seq. |
| |
| A3M conventions (ColabFold): |
| Line 1: optional header line beginning with '#', e.g. "#91\\t1" |
| Then alternating ">header" and sequence lines. |
| Sequence lines may contain UPPERCASE match-state letters, '-' gaps, and |
| lowercase letters denoting insertion states (not part of the alignment). |
| """ |
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import List, Optional, Tuple |
|
|
| import numpy as np |
|
|
|
|
@dataclass
class Pool:
    """Aligned sequences bundled with their per-residue FI matrix.

    Attributes:
        headers:      short headers (first whitespace-separated token).
        sequences:    aligned sequences (lowercase insertion states preserved).
        fi_matrix:    np.ndarray (N, L) of per-residue FI; columns correspond
                      to match-state (uppercase) positions in the aligned
                      sequences.
        header_line:  original '#' header line, or None if absent.
        full_headers: full header strings; defaults to an empty list.
    """

    headers: List[str]
    sequences: List[str]
    fi_matrix: np.ndarray
    header_line: Optional[str] = None
    full_headers: List[str] = field(default_factory=list)

    def __len__(self) -> int:
        """Return the number of entries, i.e. the number of headers."""
        return len(self.headers)

    @property
    def n_seq(self) -> int:
        """Number of sequences held (same value as ``len(pool)``)."""
        return len(self)

    @property
    def n_cols(self) -> int:
        """Number of FI columns; 0 when the FI matrix has no elements."""
        # Check emptiness first so the second axis is only read from a
        # non-empty matrix.
        if not self.fi_matrix.size:
            return 0
        return int(self.fi_matrix.shape[1])
|
|
|
|
| |
| |
| |
|
|
| def read_a3m(path: str | Path) -> Tuple[Optional[str], List[Tuple[str, str]]]: |
| """Read an A3M file. |
| |
| Returns: |
| (header_line, [(header, seq), ...]) |
| header_line is the leading '#...' line if present, else None. |
| header is the full header text without the leading '>'. |
| seq is the raw sequence line (lowercase insertion states retained). |
| """ |
| path = Path(path) |
| with open(path) as f: |
| lines = [ln.rstrip("\n") for ln in f.readlines()] |
|
|
| if not lines: |
| return None, [] |
|
|
| i = 0 |
| header_line = None |
| if lines[0].startswith("#"): |
| header_line = lines[0] |
| i = 1 |
|
|
| seqs: List[Tuple[str, str]] = [] |
| while i < len(lines): |
| ln = lines[i] |
| if ln.startswith(">"): |
| h = ln[1:] |
| s = lines[i + 1] if i + 1 < len(lines) else "" |
| seqs.append((h, s)) |
| i += 2 |
| else: |
| i += 1 |
| return header_line, seqs |
|
|
|
|
| def write_a3m(path: str | Path, |
| header_line: Optional[str], |
| seqs: List[Tuple[str, str]]) -> None: |
| """Write an A3M file. seqs = [(header, seq), ...].""" |
| path = Path(path) |
| path.parent.mkdir(parents=True, exist_ok=True) |
| with open(path, "w") as f: |
| if header_line is not None: |
| f.write(header_line + "\n") |
| for h, s in seqs: |
| f.write(f">{h}\n{s}\n") |
|
|
|
|
| |
| |
| |
|
|
def _dedup_a3m(seqs: List[Tuple[str, str]]) -> Tuple[List[int], List[Tuple[str, str]]]:
    """Deduplicate by short header (first whitespace token).

    The first occurrence of each short header is kept; later records with the
    same short header are dropped. A header that is empty or whitespace-only
    has the short header "" instead of raising IndexError.

    Args:
        seqs: records as [(header, seq), ...].

    Returns:
        (kept_indices_into_input, [(short_header, seq), ...]).
    """
    seen = set()
    keep_idx: List[int] = []
    out: List[Tuple[str, str]] = []
    for i, (h, s) in enumerate(seqs):
        # split() yields no tokens for an empty/blank header, so guard the
        # [0] lookup; maxsplit=1 avoids tokenising the whole description.
        tokens = h.split(maxsplit=1)
        short = tokens[0] if tokens else ""
        if short in seen:
            continue
        seen.add(short)
        keep_idx.append(i)
        out.append((short, s))
    return keep_idx, out
|
|
|
|
def pool_msa(a3m_path: str | Path,
             fi_npy_path: str | Path,
             *,
             dedup: bool = True) -> Pool:
    """Build a Pool from an A3M file and a per-residue FI matrix.

    Args:
        a3m_path: path to filtered.a3m (ColabFold style).
        fi_npy_path: path to FI matrix .npy of shape (N_seq, L) where
            N_seq matches the number of sequences in the A3M and
            L is the number of match-state alignment columns.
            Typically produced by FrustrAI-Seq
            (https://github.com/leuschj/FrustrAI-Seq,
            HF model: leuschj/FrustrAI-Seq).
        dedup: drop duplicates by short header (default True). The FI rows
            of dropped sequences are removed as well.

    Returns:
        Pool object whose FI matrix is converted to float64.

    Raises:
        ValueError if the FI matrix is not 2-D, or if N_seq disagree between
        the A3M and the FI matrix.
    """
    header_line, raw_seqs = read_a3m(a3m_path)
    fi = np.load(str(fi_npy_path))

    if fi.ndim != 2:
        raise ValueError(
            f"FI matrix must be 2-D (N_seq, L); got shape {fi.shape}"
        )
    if fi.shape[0] != len(raw_seqs):
        raise ValueError(
            f"FI rows ({fi.shape[0]}) != A3M sequences ({len(raw_seqs)}) "
            f"for {a3m_path}"
        )

    if dedup:
        keep_idx, kept = _dedup_a3m(raw_seqs)
        # Keep FI rows aligned with the surviving sequences.
        fi = fi[keep_idx]
        full_headers = [raw_seqs[i][0] for i in keep_idx]
        short_headers = [h for h, _ in kept]
        seqs = [s for _, s in kept]
    else:
        full_headers = [h for h, _ in raw_seqs]
        # An empty/blank header has no tokens; fall back to "" rather than
        # failing with IndexError on the [0] lookup.
        short_headers = [
            (h.split(maxsplit=1) or [""])[0] for h in full_headers
        ]
        seqs = [s for _, s in raw_seqs]

    return Pool(
        headers=short_headers,
        sequences=seqs,
        fi_matrix=np.asarray(fi, dtype=np.float64),
        header_line=header_line,
        full_headers=full_headers,
    )
|
|