nikraf's picture
Upload folder using huggingface_hub
714cf46 verified
import argparse
import multiprocessing
from dataclasses import asdict
from functools import partial
from pathlib import Path
from typing import Any
import numpy as np
from p_tqdm import p_umap
from redis import Redis
from tqdm import tqdm
from boltz.data.parse.a3m import parse_a3m
class Resource:
"""A shared resource for processing."""
def __init__(self, host: str, port: int) -> None:
"""Initialize the redis database."""
self._redis = Redis(host=host, port=port)
def get(self, key: str) -> Any: # noqa: ANN401
"""Get an item from the Redis database."""
return self._redis.get(key)
def __getitem__(self, key: str) -> Any: # noqa: ANN401
"""Get an item from the resource."""
out = self.get(key)
if out is None:
raise KeyError(key)
return out
def process_msa(
path: Path,
outdir: str,
max_seqs: int,
resource: Resource,
) -> None:
"""Run processing in a worker thread."""
outdir = Path(outdir)
out_path = outdir / f"{path.stem}.npz"
if not out_path.exists():
msa = parse_a3m(path, resource, max_seqs)
np.savez_compressed(out_path, **asdict(msa))
def process(args) -> None:
"""Run the data processing task."""
# Create output directory
args.outdir.mkdir(parents=True, exist_ok=True)
# Load the resource
resource = Resource(host=args.redis_host, port=args.redis_port)
# Get data points
print("Fetching data...")
data = list(args.msadir.rglob("*.a3m*"))
print(f"Found {len(data)} MSA's.")
# Check if we can run in parallel
max_processes = multiprocessing.cpu_count()
num_processes = max(1, min(args.num_processes, max_processes, len(data)))
parallel = num_processes > 1
# Run processing
if parallel:
# Create processing function
fn = partial(
process_msa,
outdir=args.outdir,
max_seqs=args.max_seqs,
resource=resource,
)
# Run in parallel
p_umap(fn, data, num_cpus=num_processes)
else:
# Run in serial
for path in tqdm(data):
process_msa(
path,
outdir=args.outdir,
max_seqs=args.max_seqs,
resource=resource,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process MSA data.")
parser.add_argument(
"--msadir",
type=Path,
required=True,
help="The MSA data directory.",
)
parser.add_argument(
"--outdir",
type=Path,
default="data",
help="The output directory.",
)
parser.add_argument(
"--num-processes",
type=int,
default=multiprocessing.cpu_count(),
help="The number of processes.",
)
parser.add_argument(
"--redis-host",
type=str,
default="localhost",
help="The Redis host.",
)
parser.add_argument(
"--redis-port",
type=int,
default=7777,
help="The Redis port.",
)
parser.add_argument(
"--max-seqs",
type=int,
default=16384,
help="The maximum number of sequences.",
)
args = parser.parse_args()
process(args)