File size: 3,302 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
import multiprocessing
from dataclasses import asdict
from functools import partial
from pathlib import Path
from typing import Any

import numpy as np
from p_tqdm import p_umap
from redis import Redis
from tqdm import tqdm

from boltz.data.parse.a3m import parse_a3m


class Resource:
    """A shared resource for processing."""

    def __init__(self, host: str, port: int) -> None:
        """Connect to the Redis database backing this resource."""
        self._redis = Redis(host=host, port=port)

    def get(self, key: str) -> Any:  # noqa: ANN401
        """Fetch a value from Redis, returning None when the key is absent."""
        return self._redis.get(key)

    def __getitem__(self, key: str) -> Any:  # noqa: ANN401
        """Fetch a value from Redis, raising KeyError when the key is absent."""
        value = self.get(key)
        if value is None:
            raise KeyError(key)
        return value


def process_msa(
    path: Path,
    outdir: str,
    max_seqs: int,
    resource: Resource,
) -> None:
    """Run processing in a worker thread."""
    # Output file mirrors the input stem; skip work if it already exists.
    target = Path(outdir) / f"{path.stem}.npz"
    if target.exists():
        return
    parsed = parse_a3m(path, resource, max_seqs)
    np.savez_compressed(target, **asdict(parsed))


def process(args) -> None:
    """Run the data processing task."""
    # Make sure the destination exists before any worker writes to it.
    args.outdir.mkdir(parents=True, exist_ok=True)

    # Shared Redis-backed resource used by every worker.
    resource = Resource(host=args.redis_host, port=args.redis_port)

    # Collect the MSA files to process.
    print("Fetching data...")
    data = list(args.msadir.rglob("*.a3m*"))
    print(f"Found {len(data)} MSA's.")

    # Clamp the worker count to both the machine and the workload size.
    num_processes = max(
        1, min(args.num_processes, multiprocessing.cpu_count(), len(data))
    )

    if num_processes > 1:
        # Fan the work out across processes.
        worker = partial(
            process_msa,
            outdir=args.outdir,
            max_seqs=args.max_seqs,
            resource=resource,
        )
        p_umap(worker, data, num_cpus=num_processes)
    else:
        # Single-process fallback with a progress bar.
        for path in tqdm(data):
            process_msa(
                path,
                outdir=args.outdir,
                max_seqs=args.max_seqs,
                resource=resource,
            )


def _build_arg_parser() -> argparse.ArgumentParser:
    """Construct the command-line parser for the MSA processing script."""
    parser = argparse.ArgumentParser(description="Process MSA data.")
    parser.add_argument(
        "--msadir",
        type=Path,
        required=True,
        help="The MSA data directory.",
    )
    parser.add_argument(
        "--outdir",
        type=Path,
        default="data",
        help="The output directory.",
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=multiprocessing.cpu_count(),
        help="The number of processes.",
    )
    parser.add_argument(
        "--redis-host",
        type=str,
        default="localhost",
        help="The Redis host.",
    )
    parser.add_argument(
        "--redis-port",
        type=int,
        default=7777,
        help="The Redis port.",
    )
    parser.add_argument(
        "--max-seqs",
        type=int,
        default=16384,
        help="The maximum number of sequences.",
    )
    return parser


if __name__ == "__main__":
    process(_build_arg_parser().parse_args())