| import os |
| import hydra |
| import lightning as L |
| import numpy as np |
| import omegaconf |
| import pandas as pd |
| import rdkit |
| import rich.syntax |
| import rich.tree |
| import torch |
| from tqdm.auto import tqdm |
| import pdb |
|
|
| import dataloader |
| import diffusion |
| from models.bindevaluator import BindEvaluator |
|
|
# Silence RDKit's per-molecule error logging (noisy during sequence decoding).
rdkit.rdBase.DisableLog('rdApp.error')


# Custom OmegaConf resolvers available to the YAML configs, e.g. ${cwd:}.
omegaconf.OmegaConf.register_new_resolver(
  'cwd', os.getcwd)
# ${device_count:} — number of visible CUDA devices.
omegaconf.OmegaConf.register_new_resolver(
  'device_count', torch.cuda.device_count)
# SECURITY: ${eval:...} executes arbitrary Python from the config file;
# only load configs from trusted sources.
omegaconf.OmegaConf.register_new_resolver(
  'eval', eval)
# ${div_up:x,y} — integer ceiling division.
omegaconf.OmegaConf.register_new_resolver(
  'div_up', lambda x, y: (x + y - 1) // y)
# ${if_then_else:cond,x,y} — conditional expression inside configs.
omegaconf.OmegaConf.register_new_resolver(
  'if_then_else',
  lambda condition, x, y: x if condition else y
)
|
|
def _print_config(
  config: omegaconf.DictConfig,
  resolve: bool = True) -> None:
  """Render a Hydra-composed config as a Rich tree, one branch per top key.

  Args:
    config (DictConfig): Configuration composed by Hydra.
    resolve (bool): Whether to resolve interpolations before printing.
  """
  style = 'dim'
  tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)

  for key in config.keys():
    branch = tree.add(key, style=style, guide_style=style)
    section = config.get(key)
    # Nested sections are pretty-printed as YAML; leaves fall back to str().
    text = (
      omegaconf.OmegaConf.to_yaml(section, resolve=resolve)
      if isinstance(section, omegaconf.DictConfig)
      else str(section))
    branch.add(rich.syntax.Syntax(text, 'yaml'))

  rich.print(tree)
|
|
def parse_motif(motif: str) -> torch.Tensor:
  """Parse a motif spec such as '1,3-5,8' into a 1-D index tensor.

  Args:
    motif: Comma-separated list of integer indices and/or inclusive
      'start-end' ranges. Whitespace around items is ignored; empty
      items (e.g. trailing commas) are skipped.

  Returns:
    A 1-D ``torch.Tensor`` of dtype ``torch.long`` containing every
    index, with ranges expanded inclusively
    (e.g. '1,3-5' -> tensor([1, 3, 4, 5])).
  """
  result = []
  for part in motif.split(','):
    part = part.strip()
    if not part:
      # Tolerate trailing/duplicate commas ('1,,2') instead of int('').
      continue
    if '-' in part:
      # NOTE: '-' as range separator means negative indices are not supported.
      start, end = map(int, part.split('-'))
      result.extend(range(start, end + 1))
    else:
      result.append(int(part))
  # Explicit dtype keeps the empty case integral (default would be float32).
  return torch.tensor(result, dtype=torch.long)
|
|
@hydra.main(version_base=None, config_path='./configs',
            config_name='config')
def main(config: omegaconf.DictConfig) -> None:
  """Generate sequences from a pretrained diffusion model with classifier guidance.

  Loads the diffusion checkpoint and the BindEvaluator guidance classifier,
  samples ``config.sampling.num_sample_batches`` batches conditioned on the
  target sequence/motifs, and prints the decoded samples.
  """
  # Full determinism: fixed seed plus deterministic CUDA kernels.
  L.seed_everything(config.seed)
  # Required by torch.use_deterministic_algorithms for cuBLAS ops.
  os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
  torch.use_deterministic_algorithms(True)
  torch.backends.cudnn.benchmark = False

  print(f"Checkpoint: {config.eval.checkpoint_path}")

  tokenizer = dataloader.get_tokenizer(config)
  target_sequence = tokenizer(
      config.eval.target_sequence, return_tensors='pt')['input_ids']

  pretrained = diffusion.Diffusion.load_from_checkpoint(
      config.eval.checkpoint_path,
      tokenizer=tokenizer,
      config=config, logger=False)
  pretrained.eval()
  pretrained = pretrained.to('cuda')

  # Guidance classifier; hyperparameters must match the trained checkpoint.
  bindevaluator = BindEvaluator.load_from_checkpoint(
      config.guidance.classifier_checkpoint_path,
      n_layers=8,
      d_model=128,
      d_hidden=128,
      n_head=8,
      d_k=64,
      d_v=128,
      d_inner=64)
  bindevaluator = bindevaluator.to('cuda')

  # The motif spec is loop-invariant; parse it once instead of per batch.
  target_motifs = parse_motif(config.eval.target_motifs)

  samples = []
  for _ in tqdm(
      range(config.sampling.num_sample_batches),
      desc='Gen. batches', leave=False):
    sample = pretrained.sample(
        target_sequence=target_sequence,
        target_motifs=target_motifs,
        classifier_model=bindevaluator)
    samples.extend(
        pretrained.tokenizer.batch_decode(sample))

  # Drop tokenizer spacing and the 5-char markers at each end
  # (presumably '[CLS]'/'[SEP]'-style special tokens — TODO confirm).
  # The original printed this list twice; print once.
  samples = [sample.replace(' ', '')[5:-5] for sample in samples]
  print(samples)
|
# Script entry point: Hydra parses CLI overrides before invoking main().
if __name__ == '__main__':
  main()