import argparse
import multiprocessing
import tempfile
from dataclasses import asdict
from functools import partial
from pathlib import Path
from typing import Any

import numpy as np
from p_tqdm import p_umap
from redis import Redis
from tqdm import tqdm

from boltz.data.parse.a3m import parse_a3m
| |
|
| |
|
class Resource:
    """Dictionary-style read access to values stored in a Redis server.

    An instance is handed to ``parse_a3m`` by ``process_msa``. Every lookup
    goes straight to Redis; nothing is cached locally.
    """

    def __init__(self, host: str, port: int) -> None:
        """Create the Redis client used for every lookup.

        Parameters
        ----------
        host : str
            Hostname of the Redis server.
        port : int
            Port of the Redis server.
        """
        client = Redis(host=host, port=port)
        self._redis = client

    def get(self, key: str) -> Any:
        """Return the value stored under ``key`` as given by the Redis client.

        ``None`` signals that no value is stored under ``key``.
        """
        value = self._redis.get(key)
        return value

    def __getitem__(self, key: str) -> Any:
        """Return the value stored under ``key``.

        Raises
        ------
        KeyError
            If the lookup yields ``None`` (no value stored under ``key``).
        """
        value = self.get(key)
        if value is not None:
            return value
        raise KeyError(key)
| |
|
| |
|
def _msa_stem(path: Path) -> str:
    """Return the output base name for an A3M file (``abc.a3m[.gz]`` -> ``abc``)."""
    # Strip from the last ".a3m" onward so compressed inputs such as
    # "abc.a3m.gz" map to the same name as "abc.a3m"; ``path.stem`` alone
    # would only drop ".gz" and yield "abc.a3m".
    head, sep, _ = path.name.rpartition(".a3m")
    return head if sep and head else path.stem


def process_msa(
    path: Path,
    outdir: str,
    max_seqs: int,
    resource: Resource,
) -> None:
    """Parse one A3M file and save the result as a compressed ``.npz`` archive.

    The output is written to ``outdir / "<name>.npz"``, where ``<name>`` is the
    file name with its ``.a3m`` extension and any trailing suffix after it
    (e.g. ``.gz``) removed. If that file already exists the MSA is skipped,
    which lets an interrupted run be resumed.

    Parameters
    ----------
    path : Path
        The A3M file to parse.
    outdir : str or Path
        Directory that receives the ``.npz`` file; it must already exist.
    max_seqs : int
        Forwarded to ``parse_a3m`` as its sequence limit.
    resource : Resource
        Lookup resource forwarded to ``parse_a3m``.
    """
    outdir = Path(outdir)
    out_path = outdir / f"{_msa_stem(path)}.npz"
    if out_path.exists():
        return

    msa = parse_a3m(path, resource, max_seqs)

    # Write to a hidden temp file in the same directory, then atomically move
    # it into place. A crash mid-write therefore never leaves a truncated
    # ``.npz`` that the ``exists()`` check above would mistake for a finished
    # result. A file object is passed to ``savez_compressed`` so numpy does
    # not append ".npz" to the temp name.
    tmp = tempfile.NamedTemporaryFile(
        dir=outdir, prefix=f".{out_path.name}.", suffix=".tmp", delete=False
    )
    tmp_path = Path(tmp.name)
    try:
        with tmp:
            np.savez_compressed(tmp, **asdict(msa))
        tmp_path.replace(out_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
| |
|
| |
|
def process(args) -> None:
    """Convert every A3M file found under ``args.msadir`` into ``.npz`` files.

    Attributes read from ``args``: ``outdir``, ``redis_host``, ``redis_port``,
    ``msadir``, ``num_processes`` and ``max_seqs``. The worker count is
    ``args.num_processes`` capped by the CPU count and the number of files
    (minimum 1). With one worker the files are handled serially in this
    process; otherwise they are distributed with ``p_umap``.
    """
    args.outdir.mkdir(parents=True, exist_ok=True)
    resource = Resource(host=args.redis_host, port=args.redis_port)

    print("Fetching data...")
    msa_files = list(args.msadir.rglob("*.a3m*"))
    print(f"Found {len(msa_files)} MSA's.")

    # Every per-file call shares the same output dir, limit and resource.
    worker = partial(
        process_msa,
        outdir=args.outdir,
        max_seqs=args.max_seqs,
        resource=resource,
    )

    workers = max(
        1,
        min(args.num_processes, multiprocessing.cpu_count(), len(msa_files)),
    )
    if workers <= 1:
        for msa_file in tqdm(msa_files):
            worker(msa_file)
        return

    p_umap(worker, msa_files, num_cpus=workers)
| |
|
| |
|
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Process MSA data.")

    # (flag, value type, extra add_argument kwargs, help text)
    cli_options = [
        ("--msadir", Path, {"required": True}, "The MSA data directory."),
        ("--outdir", Path, {"default": "data"}, "The output directory."),
        (
            "--num-processes",
            int,
            {"default": multiprocessing.cpu_count()},
            "The number of processes.",
        ),
        ("--redis-host", str, {"default": "localhost"}, "The Redis host."),
        ("--redis-port", int, {"default": 7777}, "The Redis port."),
        (
            "--max-seqs",
            int,
            {"default": 16384},
            "The maximum number of sequences.",
        ),
    ]
    for flag, value_type, extra, help_text in cli_options:
        cli.add_argument(flag, type=value_type, help=help_text, **extra)

    process(cli.parse_args())
| |
|