| | """DDP inference script.""" |
| | import os |
| | import time |
| | import numpy as np |
| | import hydra |
| | import torch |
| | import GPUtil |
| | import sys |
| |
|
| | from pytorch_lightning import Trainer |
| | from omegaconf import DictConfig, OmegaConf |
| | from experiments import utils as eu |
| | from models.flow_module import FlowModule |
| | import re |
| | from typing import Optional |
| | import subprocess |
| | from biotite.sequence.io import fasta |
| | from data import utils as du |
| | from analysis import metrics |
| | import pandas as pd |
| | import esm |
| | import shutil |
| | import biotite.structure.io as bsio |
| |
|
| |
|
# 'high' permits TF32 matmuls on supporting GPUs — faster inference at
# slightly reduced float32 precision.
torch.set_float32_matmul_precision('high')

# Module-level logger shared by the whole script.
log = eu.get_pylogger(__name__)
| |
|
class Sampler:
    """Loads a FlowModule checkpoint and runs sampling via a Lightning Trainer.

    Also loads ESMFold onto the current GPU; the folding model is not
    referenced elsewhere in this file — presumably consumed downstream by the
    prediction outputs (TODO confirm).
    """

    def __init__(self, cfg: "DictConfig"):
        """Initialize sampler.

        Args:
            cfg: inference config. ``cfg.inference`` must provide
                ``ckpt_path``, ``output_dir``, ``name`` and ``samples``.
        """
        ckpt_path = cfg.inference.ckpt_path

        # Allow assigning keys not declared in the config schema below.
        OmegaConf.set_struct(cfg, False)

        self._cfg = cfg
        self._infer_cfg = cfg.inference
        self._samples_cfg = self._infer_cfg.samples

        # e.g. ".../weights/run_x/last.ckpt" -> "weights/run_x/last"
        self._ckpt_name = self._checkpoint_name(ckpt_path)
        self._output_dir = os.path.join(
            self._infer_cfg.output_dir,
            self._ckpt_name,
            self._infer_cfg.name,
        )
        os.makedirs(self._output_dir, exist_ok=True)
        log.info(f'Saving results to {self._output_dir}')

        # Persist the exact config used for this run next to its outputs.
        config_path = os.path.join(self._output_dir, 'config.yaml')
        with open(config_path, 'w') as f:
            OmegaConf.save(config=self._cfg, f=f)
        log.info(f'Saving inference config to {config_path}')

        # Restore the flow model and attach the inference settings it reads
        # during predict().
        self._flow_module = FlowModule.load_from_checkpoint(
            checkpoint_path=ckpt_path,
        )
        self._flow_module.eval()
        self._flow_module._infer_cfg = self._infer_cfg
        self._flow_module._samples_cfg = self._samples_cfg
        self._flow_module._output_dir = self._output_dir

        # Place ESMFold on the currently active CUDA device.
        devices = [torch.cuda.current_device()]
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model = self._folding_model.to(devices[-1])

    @staticmethod
    def _checkpoint_name(ckpt_path: str) -> str:
        """Return a short run identifier: the last three path components with
        the '.ckpt' extension stripped.

        Strips '.ckpt' only as a trailing suffix. The previous
        ``replace('.ckpt', '')`` removed the substring anywhere in the path,
        mangling directory names that happened to contain '.ckpt'.
        """
        if ckpt_path.endswith('.ckpt'):
            ckpt_path = ckpt_path[:-len('.ckpt')]
        return '/'.join(ckpt_path.split('/')[-3:])

    def run_sampling(self):
        """Run prediction over the configured length dataset on the current GPU."""
        devices = [torch.cuda.current_device()]
        log.info(f"Using devices: {devices}")

        eval_dataset = eu.LengthDataset(self._samples_cfg)
        dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=self._samples_cfg.sample_batch,
            shuffle=False,
            drop_last=False,
        )
        trainer = Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=devices,
        )
        trainer.predict(self._flow_module, dataloaders=dataloader)
| | |
| |
|
| |
|
@hydra.main(version_base=None, config_path="../configs", config_name="inference")
def run(cfg: DictConfig) -> None:
    """Hydra entry point: build a Sampler from the config, sample, log timing."""
    log.info(f'Starting inference with {cfg.inference.num_gpus} GPUs')
    t0 = time.time()
    Sampler(cfg).run_sampling()
    log.info(f'Finished in {time.time() - t0:.2f}s')
| |
|
# Script entry point: hydra handles config composition inside run().
if __name__ == '__main__':
    run()
| |
|