import argparse
import multiprocessing
from dataclasses import asdict
from functools import partial
from pathlib import Path
from typing import Any
import numpy as np
from p_tqdm import p_umap
from redis import Redis
from tqdm import tqdm
from boltz.data.parse.a3m import parse_a3m
class Resource:
    """Shared key-value lookup backed by a Redis connection."""

    def __init__(self, host: str, port: int) -> None:
        """Open a connection to the Redis server at ``host``:``port``."""
        self._redis = Redis(host=host, port=port)

    def get(self, key: str) -> Any:  # noqa: ANN401
        """Return the value stored under ``key``, or ``None`` if absent."""
        return self._redis.get(key)

    def __getitem__(self, key: str) -> Any:  # noqa: ANN401
        """Return the value stored under ``key``.

        Raises a ``KeyError`` when the key is not present in the database.
        """
        value = self.get(key)
        if value is None:
            raise KeyError(key)
        return value
def process_msa(
path: Path,
outdir: str,
max_seqs: int,
resource: Resource,
) -> None:
"""Run processing in a worker thread."""
outdir = Path(outdir)
out_path = outdir / f"{path.stem}.npz"
if not out_path.exists():
msa = parse_a3m(path, resource, max_seqs)
np.savez_compressed(out_path, **asdict(msa))
def process(args) -> None:
    """Run the data processing task."""
    # Make sure the output directory exists before workers write to it
    args.outdir.mkdir(parents=True, exist_ok=True)

    # Connect to the shared Redis-backed resource
    resource = Resource(host=args.redis_host, port=args.redis_port)

    # Collect every A3M file under the MSA directory
    print("Fetching data...")
    data = list(args.msadir.rglob("*.a3m*"))
    print(f"Found {len(data)} MSA's.")

    # Clamp the worker count to the CPU count and the amount of work
    max_processes = multiprocessing.cpu_count()
    num_processes = max(1, min(args.num_processes, max_processes, len(data)))

    if num_processes > 1:
        # Parallel path: bind the fixed arguments and fan out over files
        worker = partial(
            process_msa,
            outdir=args.outdir,
            max_seqs=args.max_seqs,
            resource=resource,
        )
        p_umap(worker, data, num_cpus=num_processes)
    else:
        # Serial fallback with a progress bar
        for path in tqdm(data):
            process_msa(
                path,
                outdir=args.outdir,
                max_seqs=args.max_seqs,
                resource=resource,
            )
if __name__ == "__main__":
    # Command-line interface for the MSA processing pipeline.
    parser = argparse.ArgumentParser(description="Process MSA data.")
    parser.add_argument(
        "--msadir", type=Path, required=True, help="The MSA data directory."
    )
    parser.add_argument(
        "--outdir", type=Path, default="data", help="The output directory."
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=multiprocessing.cpu_count(),
        help="The number of processes.",
    )
    parser.add_argument(
        "--redis-host", type=str, default="localhost", help="The Redis host."
    )
    parser.add_argument(
        "--redis-port", type=int, default=7777, help="The Redis port."
    )
    parser.add_argument(
        "--max-seqs",
        type=int,
        default=16384,
        help="The maximum number of sequences.",
    )
    process(parser.parse_args())