chq1155's picture
Initial OSS release: mosaic + gradient subset builders (verified KaiB 95.0%, GA98 92.5%, GB98 50.0% on Phase XII pilot)
ccbe063 verified
"""A3M parsing and pool construction.
The pool ties together aligned sequences from a ColabFold-style A3M and a
per-residue Frustration Index (FI) matrix produced by FrustrAI-Seq.
A3M conventions (ColabFold):
Line 1: optional header line beginning with '#', e.g. "#91\\t1"
Then alternating ">header" and sequence lines.
Sequence lines may contain UPPERCASE match-state letters, '-' gaps, and
lowercase letters denoting insertion states (not part of the alignment).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
@dataclass
class Pool:
    """Container for sequences + per-residue FI vectors.

    Attributes:
        headers: short headers (first whitespace-separated token of each
            A3M header), one per sequence.
        sequences: aligned sequences (lowercase insertion states preserved).
        fi_matrix: (N, L) per-residue FI array; columns correspond to
            match-state (uppercase) positions in the aligned sequences.
        header_line: original '#' header line of the A3M, if present.
        full_headers: full header text (without the leading '>') for each
            sequence, parallel to ``headers``; an empty list when not given.
    """

    headers: List[str]
    sequences: List[str]
    fi_matrix: np.ndarray
    header_line: Optional[str] = None
    full_headers: List[str] = field(default_factory=list)

    def __len__(self) -> int:
        """Number of sequences in the pool (same as ``n_seq``)."""
        return len(self.headers)

    @property
    def n_seq(self) -> int:
        """Number of sequences in the pool."""
        return len(self.headers)

    @property
    def n_cols(self) -> int:
        """Number of FI columns (L).

        Returns 0 when ``fi_matrix`` has fewer than two dimensions (e.g. an
        empty placeholder array). A 2-D matrix with zero rows still reports
        its column count, so an empty pool keeps its alignment width.
        """
        # Gate on ndim rather than size: shape (0, L) must report L, not 0.
        if self.fi_matrix.ndim < 2:
            return 0
        return int(self.fi_matrix.shape[1])
# ---------------------------------------------------------------------------
# A3M I/O
# ---------------------------------------------------------------------------
def read_a3m(path: str | Path) -> Tuple[Optional[str], List[Tuple[str, str]]]:
    """Read an A3M file.

    Parsing rules:
      * If the first line starts with '#', it is returned as ``header_line``.
      * Each line starting with '>' opens a new record. Every following line
        up to the next '>' line belongs to that record's sequence. Wrapped
        (multi-line) sequences are concatenated, and surrounding whitespace
        on each sequence line is stripped, so blank lines contribute nothing.
      * A '>' line directly followed by another '>' line (or end of file)
        yields an empty sequence rather than swallowing the next header.
      * Non-'#' lines appearing before the first '>' are ignored.

    Returns:
        (header_line, [(header, seq), ...])
        header_line is the leading '#...' line if present, else None.
        header is the full header text without the leading '>'.
        seq is the sequence text (lowercase insertion states retained).
    """
    path = Path(path)
    with open(path) as f:
        lines = [ln.rstrip("\n") for ln in f]
    if not lines:
        return None, []

    header_line: Optional[str] = None
    body = lines
    if lines[0].startswith("#"):
        header_line = lines[0]
        body = lines[1:]

    seqs: List[Tuple[str, str]] = []
    cur_header: Optional[str] = None
    cur_parts: List[str] = []
    for ln in body:
        if ln.startswith(">"):
            # Flush the previous record before starting a new one.
            if cur_header is not None:
                seqs.append((cur_header, "".join(cur_parts)))
            cur_header = ln[1:]
            cur_parts = []
        elif cur_header is not None:
            cur_parts.append(ln.strip())
    if cur_header is not None:
        seqs.append((cur_header, "".join(cur_parts)))
    return header_line, seqs
def write_a3m(path: str | Path,
              header_line: Optional[str],
              seqs: List[Tuple[str, str]]) -> None:
    """Serialise records to an A3M file, creating parent directories.

    Args:
        path: destination file; overwritten if it already exists.
        header_line: optional '#...' line written first, followed by a
            newline. Skipped entirely when None.
        seqs: ``(header, seq)`` pairs, each emitted as ``>header`` on one
            line and the sequence on the next, in the given order.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w") as handle:
        if header_line is not None:
            handle.write(header_line + "\n")
        handle.writelines(f">{name}\n{seq}\n" for name, seq in seqs)
# ---------------------------------------------------------------------------
# Pool construction
# ---------------------------------------------------------------------------
def _dedup_a3m(seqs: List[Tuple[str, str]]) -> Tuple[List[int], List[Tuple[str, str]]]:
"""Deduplicate by short header (first whitespace token).
Returns (kept_indices_into_input, [(short_header, seq), ...]).
"""
seen = set()
keep_idx: List[int] = []
out: List[Tuple[str, str]] = []
for i, (h, s) in enumerate(seqs):
short = h.split()[0]
if short in seen:
continue
seen.add(short)
keep_idx.append(i)
out.append((short, s))
return keep_idx, out
def pool_msa(a3m_path: str | Path,
             fi_npy_path: str | Path,
             *,
             dedup: bool = True) -> Pool:
    """Build a Pool from an A3M file and a per-residue FI matrix.

    Args:
        a3m_path: path to filtered.a3m (ColabFold style).
        fi_npy_path: path to FI matrix .npy of shape (N_seq, L) where
            N_seq matches the number of sequences in the A3M and
            L is the number of match-state alignment columns.
            Typically produced by FrustrAI-Seq
            (https://github.com/leuschj/FrustrAI-Seq,
            HF model: leuschj/FrustrAI-Seq).
        dedup: drop duplicates by short header (default True).

    Returns:
        Pool whose ``headers``, ``full_headers``, ``sequences`` and
        ``fi_matrix`` rows are aligned with one another. The FI matrix is
        converted to float64.

    Raises:
        ValueError: if the FI matrix is not 2-D, or if its row count
            disagrees with the number of sequences in the A3M.
    """
    header_line, raw_seqs = read_a3m(a3m_path)
    fi = np.load(str(fi_npy_path))
    if fi.ndim != 2:
        raise ValueError(
            f"FI matrix must be 2-D (N_seq, L); got shape {fi.shape}"
        )
    # Validate against the raw (pre-dedup) record count: FI rows are
    # expected to correspond 1:1 with the A3M records as written.
    if fi.shape[0] != len(raw_seqs):
        raise ValueError(
            f"FI rows ({fi.shape[0]}) != A3M sequences ({len(raw_seqs)}) "
            f"for {a3m_path}"
        )
    if dedup:
        # The same kept indices select FI rows and full headers so every
        # per-sequence list stays row-aligned with the FI matrix.
        keep_idx, kept = _dedup_a3m(raw_seqs)
        fi = fi[keep_idx]
        full_headers = [raw_seqs[i][0] for i in keep_idx]
        short_headers = [h for h, _ in kept]
        seqs = [s for _, s in kept]
    else:
        full_headers = [h for h, _ in raw_seqs]
        # Empty/whitespace-only headers have no first token; map them to ""
        # rather than raising IndexError.
        short_headers = [(h.split(maxsplit=1) or [""])[0] for h in full_headers]
        seqs = [s for _, s in raw_seqs]
    return Pool(
        headers=short_headers,
        sequences=seqs,
        fi_matrix=np.asarray(fi, dtype=np.float64),
        header_line=header_line,
        full_headers=full_headers,
    )