File size: 8,085 Bytes
ccbe063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""Subset-construction methods: mosaic and gradient.

Both methods take a per-sequence score (typically `contrast_hvlv`) and
produce N_SUBSETS lists of pool indices of length TARGET_SIZE.

Defaults match the published SF-Cluster Phase XII protocol:
    N_SUBSETS = 12
    TARGET_SIZE = 32
    mosaic seeds:   s = 0, 1, ..., N_SUBSETS-1
    gradient seeds: bin_i * 10 + s   for s in {0, 1, 2}, bin_i in {0..3}
"""
from __future__ import annotations

from pathlib import Path
from typing import List, Optional, Sequence

import numpy as np

from .pool import Pool, pool_msa, write_a3m
from .score import contrast_hvlv

# Default number of subsets produced by each method (published protocol value).
N_SUBSETS: int = 12
# Default number of pool sequences per subset (published protocol value).
TARGET_SIZE: int = 32


def _subsample(indices: Sequence[int], size: int, rng: np.random.Generator) -> List[int]:
    """Draw `size` elements of `indices` with the generator `rng`.

    Sampling is without replacement when at least `size` candidates exist,
    and with replacement otherwise (so duplicates may appear). An empty
    `indices` yields an empty list regardless of `size`.
    """
    candidates = list(indices)
    if not candidates:
        return []
    # Only fall back to replacement when the population is too small.
    with_replacement = len(candidates) < size
    drawn = rng.choice(candidates, size=size, replace=with_replacement)
    return list(drawn)


# ---------------------------------------------------------------------------
# Method: mosaic
# ---------------------------------------------------------------------------

def method_mosaic(score: np.ndarray,
                  n_subsets: int = N_SUBSETS,
                  subset_size: int = TARGET_SIZE,
                  *,
                  high_n: int = 11,
                  low_n: int = 11,
                  mid_n: int = 10) -> List[List[int]]:
    """Build tri-stratified "mosaic" subsets mixing low, mid and high scores.

    Pool indices are ordered by ascending `score` and cut into three tiers by
    integer division of N: the lowest ``N // 3`` indices, the middle slice
    ``[N // 3, 2 * N // 3)``, and the remaining top indices. Each subset draws
    `high_n` indices from the high tier, then `low_n` from the low tier, then
    `mid_n` from the mid tier, all from one generator per subset.

    Seeds: subset ``s`` uses ``np.random.default_rng(seed=s)``.

    Args:
        score:       (N,) per-pool-sequence score (e.g., contrast_hvlv).
        n_subsets:   number of subsets to build (default 12).
        subset_size: total seqs per subset; must equal high_n+low_n+mid_n.
        high_n, low_n, mid_n: per-tier sample counts (defaults 11/11/10).

    Returns:
        List of `n_subsets` lists of int pool indices, ordered high, low, mid.
        Each list has `subset_size` entries unless a tier is empty (which
        happens when N < 3); an empty tier contributes nothing. A tier with
        fewer members than requested is sampled with replacement, so an index
        may repeat within a subset.

    Raises:
        ValueError: if the tier counts do not sum to `subset_size`, if
            `score` is not 1-D, or if it is empty.
    """
    if high_n + low_n + mid_n != subset_size:
        raise ValueError(
            f"high_n+low_n+mid_n ({high_n+low_n+mid_n}) != subset_size ({subset_size})"
        )
    values = np.asarray(score)
    if values.ndim != 1:
        raise ValueError("score must be 1-D")
    n_seqs = values.shape[0]
    if n_seqs == 0:
        raise ValueError("empty score array")

    order = np.argsort(values)
    one_third, two_thirds = n_seqs // 3, 2 * n_seqs // 3
    low_tier = list(order[:one_third])
    mid_tier = list(order[one_third:two_thirds])
    high_tier = list(order[two_thirds:])

    def _draw(seed: int) -> List[int]:
        # Draw order (high, low, mid) fixes how the generator stream is consumed.
        rng = np.random.default_rng(seed=seed)
        picks = _subsample(high_tier, high_n, rng)
        picks += _subsample(low_tier, low_n, rng)
        picks += _subsample(mid_tier, mid_n, rng)
        return [int(i) for i in picks]

    return [_draw(s) for s in range(n_subsets)]


# ---------------------------------------------------------------------------
# Method: gradient
# ---------------------------------------------------------------------------

def method_gradient(score: np.ndarray,
                    n_subsets: int = N_SUBSETS,
                    subset_size: int = TARGET_SIZE,
                    *,
                    n_bins: int = 4,
                    subsets_per_bin: int = 3) -> List[List[int]]:
    """Build homogeneous subsets, each drawn entirely from one score bin.

    Pool indices are ordered by ascending `score` and split into `n_bins`
    contiguous bins with integer edges ``b * N // n_bins``, so bin sizes
    differ by at most one. For every bin, taken in ascending-score order,
    `subsets_per_bin` subsets of `subset_size` indices are drawn from that
    bin alone.

    Default 4 bins x 3 subsets-per-bin = 12 subsets.

    Seeds: the subset for bin ``bin_i`` and repeat ``s`` uses
    ``np.random.default_rng(seed=bin_i * 10 + s)``.
    NOTE(review): with subsets_per_bin > 10 these seeds repeat across bins
    (e.g. bin 0 / s=10 and bin 1 / s=0 both use seed 10).

    Args:
        score:           (N,) per-pool-sequence score.
        n_subsets:       expected total (must == n_bins * subsets_per_bin).
        subset_size:     seqs per subset.
        n_bins:          number of score quantile bins (default 4).
        subsets_per_bin: subsets drawn per bin (default 3).

    Returns:
        List of `n_subsets` lists of int pool indices, grouped bin by bin.
        A bin smaller than `subset_size` is sampled with replacement (indices
        may repeat); an empty bin (possible when N < n_bins) yields empty
        subsets.

    Raises:
        ValueError: if ``n_bins * subsets_per_bin != n_subsets``, if `score`
            is not 1-D, or if it is empty.
    """
    if n_bins * subsets_per_bin != n_subsets:
        raise ValueError(
            f"n_bins*subsets_per_bin ({n_bins*subsets_per_bin}) != n_subsets ({n_subsets})"
        )
    values = np.asarray(score)
    if values.ndim != 1:
        raise ValueError("score must be 1-D")
    n_seqs = values.shape[0]
    if n_seqs == 0:
        raise ValueError("empty score array")

    order = np.argsort(values)
    # Equal-quantile bins by integer split (matches reference impl for n_bins=4).
    score_bins = [
        list(order[(b * n_seqs) // n_bins:((b + 1) * n_seqs) // n_bins])
        for b in range(n_bins)
    ]

    def _draw(members: List[int], seed: int) -> List[int]:
        rng = np.random.default_rng(seed=seed)
        return [int(x) for x in _subsample(members, subset_size, rng)]

    return [
        _draw(members, bin_i * 10 + s)
        for bin_i, members in enumerate(score_bins)
        for s in range(subsets_per_bin)
    ]


# ---------------------------------------------------------------------------
# High-level convenience: build_subsets
# ---------------------------------------------------------------------------

def _write_subset_a3ms(pool: Pool,
                       subsets: List[List[int]],
                       out_dir: Path,
                       method: str,
                       query_index: int = 0) -> List[Path]:
    """Write each subset to its own A3M file under `out_dir`.

    Every file begins with the query row ``pool[query_index]``. Subset members
    follow in their given order, skipping any row whose header was already
    written (including the query's own header), so repeated draws appear once.
    Files are named ``{method}_subset_{i:03d}.a3m`` and `out_dir` is created
    if missing.

    Returns:
        Paths of the written files, in subset order.
    """
    target = Path(out_dir)
    target.mkdir(parents=True, exist_ok=True)
    query = (pool.headers[query_index], pool.sequences[query_index])

    written: List[Path] = []
    for subset_no, members in enumerate(subsets):
        records = [query]
        emitted_headers = {query[0]}
        for row in members:
            header = pool.headers[row]
            if header not in emitted_headers:
                emitted_headers.add(header)
                records.append((header, pool.sequences[row]))
        path = target / f"{method}_subset_{subset_no:03d}.a3m"
        write_a3m(path, pool.header_line, records)
        written.append(path)
    return written


def _method_kwargs(method: str, n_subsets: int, subset_size: int) -> dict:
    """Return the method-specific keyword arguments implied by the sizes.

    mosaic:   split `subset_size` into high/low/mid tiers as evenly as
              possible, giving the remainder to high then low
              (32 -> 11/11/10, the method defaults).
    gradient: 4 score bins with ``n_subsets // 4`` subsets each
              (12 -> 4 x 3, the method defaults).

    Raises:
        ValueError: for an unknown `method`, or for "gradient" when
            `n_subsets` is not a multiple of the 4 bins.
    """
    if method == "mosaic":
        base, extra = divmod(subset_size, 3)
        return {
            "high_n": base + int(extra > 0),
            "low_n": base + int(extra > 1),
            "mid_n": base,
        }
    if method == "gradient":
        n_bins = 4
        if n_subsets % n_bins != 0:
            raise ValueError(
                f"gradient method needs n_subsets divisible by {n_bins} (got {n_subsets})"
            )
        return {"n_bins": n_bins, "subsets_per_bin": n_subsets // n_bins}
    raise ValueError(f"unknown method: {method!r} (expected 'mosaic' or 'gradient')")


def build_subsets(a3m_path: str | Path,
                  fi_npy_path: str | Path,
                  method: str = "mosaic",
                  *,
                  n_subsets: int = N_SUBSETS,
                  subset_size: int = TARGET_SIZE,
                  hv_percentile: float = 80.0,
                  out_dir: Optional[str | Path] = None,
                  query_index: int = 0):
    """End-to-end: pool -> score -> subset indices [-> A3M files].

    The per-method layout is derived from the requested sizes, so any
    `subset_size` / `n_subsets` is honored rather than only the defaults:
      * mosaic:   tier counts split `subset_size` as evenly as possible,
                  remainder to high then low (32 -> 11/11/10).
      * gradient: 4 score bins with ``n_subsets // 4`` subsets per bin, so
                  `n_subsets` must be a multiple of 4 (12 -> 4 x 3).

    Args:
        a3m_path:     input filtered A3M.
        fi_npy_path:  per-residue FI matrix (N_seq, L) .npy.
        method:       "mosaic" or "gradient".
        n_subsets:    default 12.
        subset_size:  default 32.
        hv_percentile: HV-column variance percentile for contrast_hvlv.
        out_dir:      if given, write one A3M per subset there.
        query_index:  which pool row is the query seq (placed first).

    Returns:
        (pool, score, subsets) or (pool, score, subsets, paths) if out_dir.

    Raises:
        ValueError: for an unknown `method` or a gradient `n_subsets` that is
            not a multiple of 4 (both checked before any file is read), plus
            whatever the selected method raises for an invalid score.
    """
    # Validate and derive method parameters before doing any (expensive) I/O.
    method_kwargs = _method_kwargs(method, n_subsets, subset_size)

    pool = pool_msa(a3m_path, fi_npy_path)
    score = contrast_hvlv(pool.fi_matrix, percentile=hv_percentile)

    builder = method_mosaic if method == "mosaic" else method_gradient
    subsets = builder(score, n_subsets=n_subsets, subset_size=subset_size,
                      **method_kwargs)

    if out_dir is None:
        return pool, score, subsets

    paths = _write_subset_a3ms(pool, subsets, Path(out_dir), method,
                               query_index=query_index)
    return pool, score, subsets, paths