Commit: add config file with ESM embedding guidance
File added: configs/config_emb_guidance.yaml (+133 lines, -0)
(diff hunk: @@ -0,0 +1,133 @@ — new file, full contents below)
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
|
| 4 |
+
- /data: peptide
|
| 5 |
+
- /model: small
|
| 6 |
+
- /strategy: ddp
|
| 7 |
+
- /noise: loglinear
|
| 8 |
+
- /lr_scheduler: cosine_decay_warmup # constant_warmup
|
| 9 |
+
- /classifier_model: null
|
| 10 |
+
- /guidance: cbg
|
| 11 |
+
|
| 12 |
+
mode: ppl_eval # train / train_classifier / ppl_eval
|
| 13 |
+
diffusion: uniform # absorbing_state / uniform
|
| 14 |
+
backbone: dit # dit / dimamba / ar
|
| 15 |
+
classifier_backbone: null
|
| 16 |
+
parameterization: d3pm # subs / d3pm / ar
|
| 17 |
+
time_conditioning: True # UDLM is conditioned on time
|
| 18 |
+
subs_masking: False
|
| 19 |
+
zero_recon_loss: True # Use for UDLM
|
| 20 |
+
T: 0 # 0 (continuous time) / 1000
|
| 21 |
+
# device: "cuda"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
is_vision: False
|
| 25 |
+
seed: 42
|
| 26 |
+
|
| 27 |
+
loader:
|
| 28 |
+
global_batch_size: 512
|
| 29 |
+
eval_global_batch_size: ${.global_batch_size}
|
| 30 |
+
# Note: batch_size and eval_batch_size are **per machine**
|
| 31 |
+
batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
|
| 32 |
+
eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
|
| 33 |
+
num_workers: 0 # ${eval:"len(__import__('os').sched_getaffinity(0))"}
|
| 34 |
+
pin_memory: True
|
| 35 |
+
persistent_workers: False # True
|
| 36 |
+
|
| 37 |
+
sampling:
|
| 38 |
+
use_cache: True
|
| 39 |
+
steps: 32
|
| 40 |
+
# Note: batch_size is **per machine**
|
| 41 |
+
batch_size: 1 # ${loader.eval_batch_size}
|
| 42 |
+
num_sample_batches: 5 # Total samples: `num_gpus` * `batch_size` * `num_sample_batches`
|
| 43 |
+
use_float64: False
|
| 44 |
+
# IL2 uniprot seq
|
| 45 |
+
original_binder: "MYRMQLLSCIALSLALVTNSAPTSSSTKKTQLQLEHLLLDLQMILNGINNYKNPKLTRMLTFKFYMPKKATELKHLQCLEEELKPLEEVLNLAQSKNFHLRPRDLISNINVIVLELKGSETTFMCEYADETATIVEFLNRWITFCQSIISTLT"
|
| 46 |
+
|
| 47 |
+
eval:
|
| 48 |
+
# checkpoint_path: '/home/tc415/discrete-diffusion-guidance/outputs/peptide/2024.12.31/122818/checkpoints/best.ckpt' # Used to evaluate a checkpoint after training.
|
| 49 |
+
# Tong uploaded checkpoint to A100, accessing through it
|
| 50 |
+
checkpoint_path: '/workspace/moPPIt-v2/PeptideUDLM.ckpt'
|
| 51 |
+
|
| 52 |
+
# # IL2RG gamma chain
|
| 53 |
+
# target_sequence: 'MLKPSLPFTSLLFLQLPLLGVGLNTTILTPNGNEDTTADFFLTTMPTDSLSVSTLPLPEVQCFVFNVEYMNCTWNSSSEPQPTNLTLHYWYKNSDNDKVQKCSHYLFSEEITSGCQLQKKEIHLYQTFVVQLQDPREPRRQATQMLKLQNLVIPWAPENLTLHKLSESQLELNWNNRFLNHCLEHLVQYRTDWDHSWTEQSVDYRHKFSLPSVDGQKRYTFRVRSRFNPLCGSAQHWSEWSHPIHWGSNTSKENPFLFALEAVVISVGSMGLIISLLCVYFWLERTMPRIPTLKNLEDLVTEYHGNFSAWSGVSKGLAESLQPDYSERLCLVSEIPPKGGALGEGPGASPCNQHSPYWAPPCYTLKPET'
|
| 54 |
+
# # motifs for gamma chain
|
| 55 |
+
# target_motifs: "125, 126, 147, 149, 181, 182, 204, 229, 230, 231, 233"
|
| 56 |
+
|
| 57 |
+
# IL2RA alpha chain
|
| 58 |
+
target_sequence: 'MDSYLLMWGLLTFIMVPGCQAELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLCTGNSSHSSWDNQCQCTSSATRNTTKQVTPQPEEQKERKTTEMQSPMQPVDQASLPGHCREPPPWENEATERIYHFVVGQMVYYQCVQGYRALHRGPAESVCKMTHGKTRWTQPQLICTGEMETSQFPGEEKPQASPEGRPESETSCLVTTTDFQIQTEMAATMETSIFTTEYQVAVAGCVFLLISVLLLSGLTWQRRQRKSRRTI'
|
| 59 |
+
# motifs for alpha chain
|
| 60 |
+
target_motifs: "22, 23, 25, 27, 46, 48, 50, 51, 56, 57, 59, 60, 62, 63, 64, 66, 78, 139, 141"
|
| 61 |
+
|
| 62 |
+
# # IL2RB beta chain
|
| 63 |
+
# target_sequence: 'MAAPALSWRLPLLILLLPLATSWASAAVNGTSQFTCFYNSRANISCVWSQDGALQDTSCQVHAWPDRRRWNQTCELLPVSQASWACNLILGAPDSQKLTTVDIVTLRVLCREGVRWRVMAIQDFKPFENLRLMAPISLQVVHVETHRCNISWEISQASHYFERHLEFEARTLSPGHTWEEAPLLTLKQKQEWICLETLTPDTQYEFQVRVKPLQGEFTTWSPWSQPLAFRTKPAALGKDTIPWLGHLLVGLSGAFGFIILVYLLINCRNTGPWLKKVLKCNTPDPSKFFSQLSSEHGGDVQKWLSSPFPSSSFSPGGLAPEISPLEVLERDKVTQLLLQQDKVPEPASLSSNHSLTSCFTNQGYFFFHLPDALEIEACQVYFTYDPYSEEDPDEGVAGAPTGSSPQPLQPLSGEDDAYCTFPSRDDLLLFSPSLLGGPSPPSTAPGGSGAGEERMPPSLQERVPRDWDPQPLGPPTPGVPDLVDFQPPPELVLREAGEEVPDAGPREGVSFPWSRPPGQGEFRALNARLPLNTDAYLSLQELQGQDPTHLV'
|
| 64 |
+
# # motifs for beta chain
|
| 65 |
+
# target_motifs: "67, 68, 95, 96, 97, 99, 100, 101, 102, 127, 159, 160, 162, 164, 214"
|
| 66 |
+
|
| 67 |
+
disable_ema: False
|
| 68 |
+
generate_samples: True
|
| 69 |
+
generated_samples_path: ''
|
| 70 |
+
max_samples: 50000
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
training:
|
| 74 |
+
ema: 0.9999
|
| 75 |
+
antithetic_sampling: True
|
| 76 |
+
importance_sampling: False
|
| 77 |
+
sampling_eps: 1e-3
|
| 78 |
+
change_of_variables: False
|
| 79 |
+
compute_loss_on_pad_tokens: True
|
| 80 |
+
use_simple_ce_loss: False # Ignore ELBO; just use CE
|
| 81 |
+
guidance: null # Can turn off with `training.guidance: null`
|
| 82 |
+
# cond_dropout: 0.0
|
| 83 |
+
|
| 84 |
+
optim:
|
| 85 |
+
weight_decay: 1e-4
|
| 86 |
+
lr: 1e-5
|
| 87 |
+
beta1: 0.9
|
| 88 |
+
beta2: 0.999
|
| 89 |
+
eps: 1e-8
|
| 90 |
+
|
| 91 |
+
trainer:
|
| 92 |
+
_target_: lightning.Trainer
|
| 93 |
+
accelerator: cuda
|
| 94 |
+
num_nodes: 1
|
| 95 |
+
devices: 2 # ${device_count:}
|
| 96 |
+
accumulate_grad_batches: 1 # ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
|
| 97 |
+
gradient_clip_val: 1.0
|
| 98 |
+
precision: 'bf16-mixed'
|
| 99 |
+
num_sanity_val_steps: 2
|
| 100 |
+
# max_epochs: 10
|
| 101 |
+
max_steps: 1652000
|
| 102 |
+
log_every_n_steps: 100
|
| 103 |
+
limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
|
| 104 |
+
limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
|
| 105 |
+
val_check_interval: 16520 # 2545
|
| 106 |
+
|
| 107 |
+
wandb:
|
| 108 |
+
project: moPPIt-v2
|
| 109 |
+
job_type: model-training
|
| 110 |
+
name: protein_medium_100epochs_lr1e-5_gradclip1_wd1e-4_dropout0.1 #epochs10_lr3e-4_bsz8_64-true_all-params_gradclip1_beta-one0.9_beta-two0.999
|
| 111 |
+
id: ${.name}
|
| 112 |
+
|
| 113 |
+
hydra:
|
| 114 |
+
run:
|
| 115 |
+
dir: ./outputs/${wandb.name} # ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
|
| 116 |
+
job:
|
| 117 |
+
chdir: true
|
| 118 |
+
|
| 119 |
+
checkpointing:
|
| 120 |
+
# Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
|
| 121 |
+
save_dir: ${cwd:}
|
| 122 |
+
# Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
|
| 123 |
+
resume_from_ckpt: False
|
| 124 |
+
resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# target_sequence: 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD'
|
| 128 |
+
# target_motifs: '305-313' # P53_1
|
| 129 |
+
# target_motifs: '371-382' # P53_2
|
| 130 |
+
# target_motifs: '351-393' # P53_3
|
| 131 |
+
# target_motifs: '210-230' # P53_4
|
| 132 |
+
# target_sequence: 'MLQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKTLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEFKTQPVHSPPP'
|
| 133 |
+
# target_motifs: '28-39' # NCAM1_ECD
|