diff --git a/configs/callbacks/checkpoint_every_n_steps.yaml b/configs/callbacks/checkpoint_every_n_steps.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..77e1b0d7716623dacca23a3b9a03ac6711e656b3
--- /dev/null
+++ b/configs/callbacks/checkpoint_every_n_steps.yaml
@@ -0,0 +1,8 @@
+checkpoint_every_n_steps:
+  _target_: lightning.pytorch.callbacks.ModelCheckpoint
+  save_top_k: -1  # Do not save any "best" models; this callback is used only to save a checkpoint every n train steps
+  save_last: True  # save model as ${save_dir}/checkpoints/last.ckpt
+  dirpath: ${checkpointing.save_dir}/checkpoints
+  verbose: True
+  auto_insert_metric_name: False
+  # every_n_train_steps: 500
diff --git a/configs/callbacks/checkpoint_monitor.yaml b/configs/callbacks/checkpoint_monitor.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ad86774a76f10907f82a087610dd175eac0ae951
--- /dev/null
+++ b/configs/callbacks/checkpoint_monitor.yaml
@@ -0,0 +1,10 @@
+checkpoint_monitor:
+  _target_: lightning.pytorch.callbacks.ModelCheckpoint
+  monitor: val/nll  # name of the logged metric that determines when the model is improving
+  mode: min  # can be "max" or "min"
+  save_top_k: 1  # save k best models (determined by the metric above)
+  save_last: False  # True = additionally always save the model from the last epoch
+  dirpath: ${checkpointing.save_dir}/checkpoints
+  filename: best
+  auto_insert_metric_name: False
+  verbose: True
diff --git a/configs/callbacks/learning_rate_monitor.yaml b/configs/callbacks/learning_rate_monitor.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..01e0d7d97cf9a7cfe8d38568d40e43d0d7a5441c
--- /dev/null
+++ b/configs/callbacks/learning_rate_monitor.yaml
@@ -0,0 +1,3 @@
+learning_rate_monitor:
+  _target_: lightning.pytorch.callbacks.LearningRateMonitor
+  logging_interval: step
diff --git a/configs/classifier_model/dimamba-classifier.yaml b/configs/classifier_model/dimamba-classifier.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..79db32dd2c01f8155f5820c9e5cd5828466ff42e
--- /dev/null
+++ b/configs/classifier_model/dimamba-classifier.yaml
@@ -0,0 +1,14 @@
+name: dimamba
+type: dimamba
+hidden_size: 256
+cond_dim: 128
+length: ${model.length}  # Same length as diffusion model
+n_blocks: 8
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
+bidirectional: True
+bidirectional_strategy: add
+bidirectional_weight_tie: True
+num_classes: ${data.num_classes}
+pooling: mean
diff --git a/configs/classifier_model/hyenadna-classifier.yaml b/configs/classifier_model/hyenadna-classifier.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b17a4076deaffb8e85e8e8ea1ecfd2a59060b93c
--- /dev/null
+++ b/configs/classifier_model/hyenadna-classifier.yaml
@@ -0,0 +1,4 @@
+name: hyena-32k
+type: hyenadna
+hyena_model_name_or_path: ???
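+# NOTE: `???` is OmegaConf's mandatory-missing marker, so this field must be
+# supplied at runtime; e.g., guidance_eval/ten_species_eval.py overrides it with
+# classifier_model.hyena_model_name_or_path=LongSafari/hyenadna-small-32k-seqlen-hf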
+n_layer: 4
\ No newline at end of file
diff --git a/configs/classifier_model/small-classifier.yaml b/configs/classifier_model/small-classifier.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d5d577e354fcf2bd3271c58185ccee82e92f7127
--- /dev/null
+++ b/configs/classifier_model/small-classifier.yaml
@@ -0,0 +1,11 @@
+name: small
+type: ddit
+hidden_size: 768
+cond_dim: 128
+length: ${model.length}  # Same length as diffusion model
+n_blocks: 12
+n_heads: 12
+scale_by_sigma: True
+dropout: 0.1
+num_classes: ${data.num_classes}
+pooling: mean
diff --git a/configs/classifier_model/tiny-classifier.yaml b/configs/classifier_model/tiny-classifier.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9b868ddef3d0188a4ad878c068991b457a7eb802
--- /dev/null
+++ b/configs/classifier_model/tiny-classifier.yaml
@@ -0,0 +1,11 @@
+name: tiny
+type: ddit
+hidden_size: 512
+cond_dim: 128
+length: ${model.length}  # Same length as diffusion model
+n_blocks: 8
+n_heads: 8
+scale_by_sigma: True
+dropout: 0.1
+num_classes: ${data.num_classes}
+pooling: mean
diff --git a/configs/classifier_model/tiny-dimamba-classifier.yaml b/configs/classifier_model/tiny-dimamba-classifier.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a75bdaed7f9429e682d75c3e6fd581179fb8fd0d
--- /dev/null
+++ b/configs/classifier_model/tiny-dimamba-classifier.yaml
@@ -0,0 +1,14 @@
+name: tiny
+type: dimamba
+hidden_size: 128
+cond_dim: 128
+length: ${model.length}  # Same length as diffusion model
+n_blocks: 4
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
+bidirectional: True
+bidirectional_strategy: add
+bidirectional_weight_tie: True
+num_classes: ${data.num_classes}
+pooling: mean
diff --git a/configs/config.yaml b/configs/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5e7b28c7ac43dfe29279e8ac686008724c13df35
--- /dev/null
+++ b/configs/config.yaml
@@ -0,0 +1,129 @@
+defaults:
+  - _self_
+  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
+  - /data: protein
+  - /model: small
+  - /strategy: ddp
+  - /noise: loglinear
+  - /lr_scheduler: cosine_decay_warmup  # constant_warmup
+  - /classifier_model: null
+  - /guidance: null
+
+mode: ppl_eval  # train / train_classifier / ppl_eval
+diffusion: uniform  # absorbing_state / uniform
+backbone: dit  # dit / dimamba / ar
+classifier_backbone: null
+parameterization: d3pm  # subs / d3pm / ar
+time_conditioning: True  # UDLM is conditioned on time
+subs_masking: False
+zero_recon_loss: True  # Use for UDLM
+T: 0  # 0 (continuous time) / 1000
+
+is_vision: False
+seed: 42
+
+loader:
+  global_batch_size: 512
+  eval_global_batch_size: ${.global_batch_size}
+  # Note: batch_size and eval_batch_size are **per machine**
+  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+  num_workers: 0  # ${eval:"len(__import__('os').sched_getaffinity(0))"}
+  pin_memory: True
+  persistent_workers: False  # True
+
+sampling:
+  use_cache: True
+  steps: 32
+  # Note: batch_size is **per machine**
+  batch_size: 1  # ${loader.eval_batch_size}
+  num_sample_batches: 10  # Total samples: `num_gpus` * `batch_size` * `num_sample_batches`
+  use_float64: False
+
+eval:
+  checkpoint_path: '/home/tc415/discrete-diffusion-guidance/outputs/peptide/2024.12.31/122818/checkpoints/best.ckpt'  # Used to evaluate a checkpoint after training.
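+  # The commented-out `target_sequence` / `target_motifs` pairs below are
+  # alternative guidance targets; `target_motifs` holds a residue range
+  # (e.g. '123-127') within the corresponding target sequence. Keep exactly
+  # one pair uncommented at a time.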
+ # target_sequence: 'MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS' + # target_motifs: '123-127' # UBC9 + # target_sequence: 'MAMAEGERTECAEPPRDEPPADGALKRAEELKTQANDYFKAKDYENAIKFYSQAIELNPSNAIYYGNRSLAYLRTECYGYALGDATRAIELDKKYIKGYYRRAASNMALGKFRAALRDYETVVKVKPHDKDAKMKYQECNKIVKQKAFERAIAGDEHKRSVVDSLDIESMTIEDEYSGPKLEDGKVTISFMKELMQWYKDQKKLHRKCAYQILVQVKEVLSKLSTLVETTLKETEKITVCGDTHGQFYDLLNIFELNGLPSETNPYIFNGDFVDRGSFSVEVILTLFGFKLLYPDHFHLLRGNHETDNMNQIYGFEGEVKAKYTAQMYELFSEVFEWLPLAQCINGKVLIMHGGLFSEDGVTLDDIRKIERNRQPPDSGPMCDLLWSDPQPQNGRSISKRGVSCQFGPDVTKAFLEENNLDYIIRSHEVKAEGYEVAHGGRCVTVFSAPNYCDQMGNKASYIHLQGSDLRPQFHQFTAVPHPNVKPMAYANTLLQLGMM' + # target_motifs: '94-100' # PPP5 + # target_sequence: 'MRHSKRTYCPDWDDKDWDYGKWRSSSSHKRRKRSHSSAQENKRCKYNHSKMCDSHYLESRSINEKDYHSRRYIDEYRNDYTQGCEPGHRQRDHESRYQNHSSKSSGRSGRSSYKSKHRIHHSTSHRRSHGKSHRRKRTRSVEDDEEGHLICQSGDVLSARYEIVDTLGEGAFGKVVECIDHKAGGRHVAVKIVKNVDRYCEAARSEIQVLEHLNTTDPNSTFRCVQMLEWFEHHGHICIVFELLGLSTYDFIKENGFLPFRLDHIRKMAYQICKSVNFLHSNKLTHTDLKPENILFVQSDYTEAYNPKIKRDERTLINPDIKVVDFGSATYDDEHHSTLVSTRHYRAPEVILALGWSQPCDVWSIGCILIEYYLGFTVFPTHDSKEHLAMMERILGPLPKHMIQKTRKRKYFHHDRLDWDEHSSAGRYVSRRCKPLKEFMLSQDVEHERLFDLIQKMLEYDPAKRITLREALKHPFFDLLKKSI' + # target_motifs: '336-342' # CLK1 + # target_sequence: 'MEYHQPEDPAPGKAGTAEAVIPENHEVLAGPDEHPQDTDARDADGEAREREPADQALLPSQCGDNLESPLPEASSAPPGPTLGTLPEVETIRACSMPQELPQSPRTRQPEPDFYCVKWIPWKGEQTPIITQSTNGPCPLLAIMNILFLQWKVKLPPQKEVITSDELMAHLGNCLLSIKPQEKSEGLQLNFQQNVDDAMTVLPKLATGLDVNVRFTGVSDFEYTPECSVFDLLGIPLYHGWLVDPQSPEAVRAVGKLSYNQLVERIITCKHSSDTNLVTEGLIAEQFLETTAAQLTYHGLCELTAAAKEGELSVFFRNNHFSTMTKHKSHLYLLVTDQGFLQEEQVVWESLHNVDGDSCFCDSDFHLSHSLGKGPGAEGGSGSPETQLQVDQDYLIALSLQQQQPRGPLGLTDLELAQQLQQEEYQQQQAAQPVRMRTRVLSLQGRGATSGRPAGERRQRPKHESDCILL' + # target_motifs: '202-210' # MINDY1 + # target_sequence: 'MTGNAGEWCLMESDPGVFTELIKGFGCRGAQVEEIWSLEPENFEKLKPVHGLIFLFKWQPGEEPAGSVVQDSRLDTIFFAKQVINNACATQAIVSVLLNCTHQDVHLGETLSEFKEFSQSFDAAMKGLALSNSDVIRQVHNSFARQQMFEFDTKTSAKEEDAFHFVSYVPVNGRLYELDGLREGPIDLGACNQDDWISAVRPVIEKRIQKYSEGEIRFNLMAIVSDRKMIYEQKIAELQRQLAEEEPMDTDQGNSMLSAIQSEVAKNQMLIEEEVQKLKRYKIENIRRKHNYLPFIMELLKTLAEHQQLIPLVEKAKEKQNAKKAQETK' + # target_motifs: '152-157' # UCHL5 + # target_sequence: 'MSSGCQKTTTSKSIPTRWVTINDATHMPHDYSTTPGGTPFIITPGGTRIIYDRQFLLECRTSPLARTPPYSLPDIPGVTSPPSKHIINVKAHNGEPLNNNIAAPADKSTGDDAQFEMDI' + # target_motifs: '40-50' # 4E-BP2 + # target_sequence: 'MASTDYSTYSQAAAQQGYSAYTAQPTQGYAQTTQAYGQQSYGTYGQPTDVSYTQAQTTATYGQTAYATSYGQPPTGYTTPTAPQAYSQPVQGYGTGAYDTTTATVTTTQASYAAQSAYGTQPAYPAYGQQPAATAPTRPQDGNKPTETSQPQSSTGGYNQPSLGYGQSNYSYPQVPGSYPMQPVTAPPSYPPTSYSSTQPTSYDQSSYSQQNTYGQPSSYGQQSSYGQQSSYGQQPPTSYPPQTGSYSQAPSQYSQQSSSYGQQNPSYDSVRRGAWGNNMNSGLNKSPPLGGAQTISKNTEQRPQPDPYQILGPTSSRLANPGSGQIQLWQFLLELLSDSANASCITWEGTNGEFKMTDPDEVARRWGERKSKPNMNYDKLSRALRYYYDKNIMTKVHGKRYAYKFDFHGIAQALQPHPTESSMYKYPSDISYMPSYHAHQQKVNFVPPHPSSMPVTSSSFFGAASQYWTSPTGGIYPNPNVPRHPNTHVPSHLGSYY' + # target_motifs: '323-330' # EWS::FLI1 + target_sequence: 'MLQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESL' + target_motifs: '415-430' # NCAM1_IG + # 
target_sequence: 'TPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEFKTQPVQGEPSAPKLEGQMGEDGNSIKVNLIKQDDGGSPIRHYLVRYRALSSEWKPEIRLPSGSDHVMLKSLDWNAEYEVYVVAENQQGKSKAAHFVFRTSAQP' + # target_motifs: '98-108' # NCAM1_FN3 + + disable_ema: False + generate_samples: True + generated_samples_path: '' + max_samples: 50_000 + +training: + ema: 0.9999 + antithetic_sampling: True + importance_sampling: False + sampling_eps: 1e-3 + change_of_variables: False + compute_loss_on_pad_tokens: True + use_simple_ce_loss: False # Ignore ELBO; just use CE + guidance: null # Can turn off with `training.guidance: null` + # cond_dropout: 0.0 + +optim: + weight_decay: 1e-4 + lr: 1e-5 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + +trainer: + _target_: lightning.Trainer + accelerator: cuda + num_nodes: 1 + devices: 2 # ${device_count:} + accumulate_grad_batches: 1 # ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}} + gradient_clip_val: 1.0 + precision: 'bf16-mixed' + num_sanity_val_steps: 2 + # max_epochs: 10 + max_steps: 1652000 + log_every_n_steps: 100 + limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run + limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run + val_check_interval: 16520 # 2545 + +wandb: + project: moPPIt-v2 + job_type: model-training + name: protein_medium_100epochs_lr1e-5_gradclip1_wd1e-4_dropout0.1 #epochs10_lr3e-4_bsz8_64-true_all-params_gradclip1_beta-one0.9_beta-two0.999 + id: ${.name} + +hydra: + run: + dir: ./outputs/${wandb.name} # ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S} + job: + chdir: true + +checkpointing: + # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is + save_dir: ${cwd:} + # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath` + resume_from_ckpt: False + resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt + + + # target_sequence: 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD' + # target_motifs: '305-313' # P53_1 + # target_motifs: '371-382' # P53_2 + # target_motifs: '351-393' # P53_3 + # target_motifs: '210-230' # P53_4 + # target_sequence: 'MLQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKTLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEFKTQPVHSPPP' + # target_motifs: '28-39' # NCAM1_ECD \ No newline at end of file diff --git a/configs/data/amazon_polarity.yaml b/configs/data/amazon_polarity.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8d55bca3d92091f148d041a7297c9a1d7c445bcc --- /dev/null +++ b/configs/data/amazon_polarity.yaml @@ -0,0 +1,10 @@ +train: amazon_polarity +valid: 
amazon_polarity +tokenizer_name_or_path: bert-base-uncased +cache_dir: /share/kuleshov/ssahoo/textdiffusion/data +wrap: False +streaming: False +override_cache: False +add_special_tokens: True +label_col: label +num_classes: 2 \ No newline at end of file diff --git a/configs/data/cifar10.yaml b/configs/data/cifar10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..67b9b8c595c968c0bab7e942e0436ea086792745 --- /dev/null +++ b/configs/data/cifar10.yaml @@ -0,0 +1,11 @@ +train: ??? # (Local) Path to CIFAR-10 training data +valid: ??? # (Local) Path to CIFAR-10 validation data +label_col: labels +num_classes: 10 +streaming: False +size: 1024 +length: 3072 +add_special_tokens: True +add_mask_token: True +tokenizer_name_or_path: raw_pixels + diff --git a/configs/data/lm1b.yaml b/configs/data/lm1b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89100bcbe8251e384386c96866dea5652c01b699 --- /dev/null +++ b/configs/data/lm1b.yaml @@ -0,0 +1,8 @@ +train: lm1b +valid: lm1b +tokenizer_name_or_path: bert-base-uncased +cache_dir: /share/kuleshov/ssahoo/textdiffusion/data +wrap: False +streaming: False +override_cache: False +add_special_tokens: True \ No newline at end of file diff --git a/configs/data/peptide.yaml b/configs/data/peptide.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9c81831dad7ad7efd8f8fbca3b4a9564b0b0a6a --- /dev/null +++ b/configs/data/peptide.yaml @@ -0,0 +1,8 @@ +train: peptide +valid: peptide +tokenizer_name_or_path: facebook/esm2_t33_650M_UR50D +cache_dir: /home/tc415/discrete-diffusion-guidance/dataset +wrap: False +streaming: False +override_cache: False +add_special_tokens: True \ No newline at end of file diff --git a/configs/data/protein.yaml b/configs/data/protein.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a48977a851d741e1a9ccac65f8e23d151a721f8c --- /dev/null +++ b/configs/data/protein.yaml @@ -0,0 +1,8 @@ +train: protein_400k +valid: protein_400k +tokenizer_name_or_path: facebook/esm2_t33_650M_UR50D +cache_dir: /home/tc415/discrete-diffusion-guidance/dataset +wrap: False +streaming: False +override_cache: False +add_special_tokens: True \ No newline at end of file diff --git a/configs/data/qm9.yaml b/configs/data/qm9.yaml new file mode 100644 index 0000000000000000000000000000000000000000..12927a71c18bb520fb20e9243967cbc2bf75cd4d --- /dev/null +++ b/configs/data/qm9.yaml @@ -0,0 +1,11 @@ +train: qm9 +valid: qm9 +tokenizer_name_or_path: yairschiff/qm9-tokenizer +cache_dir: /share/kuleshov/ssahoo/textdiffusion/data +wrap: False +streaming: False +override_cache: False +add_special_tokens: True +label_col: qed +label_col_pctile: 90 +num_classes: 2 diff --git a/configs/data/ten_species.yaml b/configs/data/ten_species.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38c59ddb2e7a731000edf6dd08d24df5625c7b88 --- /dev/null +++ b/configs/data/ten_species.yaml @@ -0,0 +1,11 @@ +train: ten_species +valid: ten_species +tokenizer_name_or_path: kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16 +cache_dir: /share/kuleshov/ssahoo/textdiffusion/data +wrap: False +streaming: False +override_cache: False +add_special_tokens: False +label_col: species_label +num_classes: 10 +rc_aug: False \ No newline at end of file diff --git a/configs/data/text8.yaml b/configs/data/text8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..49661845cf20afe6cc5cb704a1b9d1810a5cc4d8 --- /dev/null +++ b/configs/data/text8.yaml @@ -0,0 +1,9 @@ +# 
TODO: When using this dataset, set model.length = 256 to match the D3PM setup
+train: text8
+valid: text8
+tokenizer_name_or_path: text8
+cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+wrap: True
+streaming: False
+override_cache: False
+add_special_tokens: False
\ No newline at end of file
diff --git a/configs/guidance/cbg.yaml b/configs/guidance/cbg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3b8ccfc4573f5666cebfb4dad40e63da4269c747
--- /dev/null
+++ b/configs/guidance/cbg.yaml
@@ -0,0 +1,5 @@
+method: cbg
+condition: 0
+classifier_checkpoint_path: '/home/tc415/discrete-diffusion-guidance/model_path/finetune_bindevaluator_0/model-epoch=30-val_mcc=0.60-val_loss=0.51.ckpt'
+gamma: 2.0
+use_approx: False  # use first-order approximation
\ No newline at end of file
diff --git a/configs/guidance/cfg.yaml b/configs/guidance/cfg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b24943f2890897f5832b2f88d0e4201cc47b3343
--- /dev/null
+++ b/configs/guidance/cfg.yaml
@@ -0,0 +1,3 @@
+method: cfg
+condition: 0
+gamma: 1.0
diff --git a/configs/guidance/fudge.yaml b/configs/guidance/fudge.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5c8f253a51057491e7e95b4cc843bc7a129832dd
--- /dev/null
+++ b/configs/guidance/fudge.yaml
@@ -0,0 +1,5 @@
+method: fudge
+condition: 0
+classifier_checkpoint_path: ''
+topk: 20
+gamma: 1.0
diff --git a/configs/guidance/nos.yaml b/configs/guidance/nos.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..21e5b1c05fe0fe1b8ad2e49b6fb23fbcbbb496f5
--- /dev/null
+++ b/configs/guidance/nos.yaml
@@ -0,0 +1,6 @@
+method: nos
+condition: 0
+classifier_checkpoint_path: ''
+num_nos_steps: 1
+nos_step_size: 0.1
+nos_stability_coef: 0.01
diff --git a/configs/guidance/pplm.yaml b/configs/guidance/pplm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c5ab5fde2e30cf56d528a303c29a197ecd74baff
--- /dev/null
+++ b/configs/guidance/pplm.yaml
@@ -0,0 +1,6 @@
+method: pplm
+condition: 0
+classifier_checkpoint_path: ''
+num_pplm_steps: 1
+pplm_step_size: 0.1
+pplm_stability_coef: 0.01
diff --git a/configs/lr_scheduler/constant_warmup.yaml b/configs/lr_scheduler/constant_warmup.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0bfcbe3213ae1d1e568b1bca3846c1924bd62cce
--- /dev/null
+++ b/configs/lr_scheduler/constant_warmup.yaml
@@ -0,0 +1,2 @@
+_target_: transformers.get_constant_schedule_with_warmup
+num_warmup_steps: 2500
\ No newline at end of file
diff --git a/configs/lr_scheduler/cosine_decay_warmup.yaml b/configs/lr_scheduler/cosine_decay_warmup.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5fc9e213d7f1bd3e45d059477463ca83467a769b
--- /dev/null
+++ b/configs/lr_scheduler/cosine_decay_warmup.yaml
@@ -0,0 +1,7 @@
+_target_: utils.CosineDecayWarmupLRScheduler
+t_in_epochs: False
+t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
+warmup_prefix: True
+warmup_lr_init: 1e-7
+warmup_t: ${eval:0.1*${trainer.max_steps}}
+lr_min: 1e-7
diff --git a/configs/model/dimamba.yaml b/configs/model/dimamba.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dddffb2d758f874c183d5c5267e842189b9094e6
--- /dev/null
+++ b/configs/model/dimamba.yaml
@@ -0,0 +1,12 @@
+name: dimamba
+type: dimamba
+hidden_size: 256
+cond_dim: 128
+length: 32768
+n_blocks: 8
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
+bidirectional: True
+bidirectional_strategy: add
+bidirectional_weight_tie: True
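+# Assumed semantics: `bidirectional_strategy` controls how the forward and
+# reverse Mamba passes are combined ('add' here; element-wise multiplication
+# is the usual alternative in Caduceus-style bidirectional wrappers).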
diff --git a/configs/model/fudge_predictor.yaml b/configs/model/fudge_predictor.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a8ef38547b309eeeec0e2a46166edd89548eade
--- /dev/null
+++ b/configs/model/fudge_predictor.yaml
@@ -0,0 +1,4 @@
+name: fudge_predictor
+type: lstm
+hidden_dim: 300
+length: 1024
\ No newline at end of file
diff --git a/configs/model/hf.yaml b/configs/model/hf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..937fd647cf26bebae55a69f2cde22064a0f1612d
--- /dev/null
+++ b/configs/model/hf.yaml
@@ -0,0 +1,2 @@
+pretrained_model_name_or_path: null
+length: 128
diff --git a/configs/model/medium.yaml b/configs/model/medium.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f976f9cfd89abc57c9c266cf4a35df1801a964b4
--- /dev/null
+++ b/configs/model/medium.yaml
@@ -0,0 +1,10 @@
+name: medium
+type: ddit
+hidden_size: 1024
+cond_dim: 128
+length: 4096
+n_blocks: 24
+n_heads: 16
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/model/small.yaml b/configs/model/small.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..21f18c543e35fa40e9550b909ea656f445b7879e
--- /dev/null
+++ b/configs/model/small.yaml
@@ -0,0 +1,11 @@
+name: small
+type: ddit
+hidden_size: 768
+cond_dim: 128
+length: null
+length_range: '25,27,28,31,35,43-49'
+n_blocks: 12
+n_heads: 12
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/model/tiny.yaml b/configs/model/tiny.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..34a1f136d2a33db6516bde2de6674d7e100d49d8
--- /dev/null
+++ b/configs/model/tiny.yaml
@@ -0,0 +1,10 @@
+name: tiny
+type: ddit
+hidden_size: 512
+cond_dim: 128
+length: 1024
+n_blocks: 8
+n_heads: 8
+scale_by_sigma: True
+dropout: 0.1
+tie_word_embeddings: False
\ No newline at end of file
diff --git a/configs/model/unet.yaml b/configs/model/unet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3bf1961f986b8311782bcd490ec4cfdc5a60580b
--- /dev/null
+++ b/configs/model/unet.yaml
@@ -0,0 +1,19 @@
+name: unet
+type: unet
+ch: 128
+num_res_blocks: 2
+num_scales: 4
+ch_mult: [1, 2, 2, 2]
+input_channels: 3
+output_channels: -1  # determined by vocab_size
+scale_count_to_put_attn: 1  # attention placed at resolution 16
+data_min_max: [0, 255]  # not currently used
+dropout: 0.1
+skip_rescale: True
+time_conditioning: True  # whether to add in time embeddings
+time_scale_factor: 1000
+time_embed_dim: ${.ch}
+fix_logistic: False
+size: ${data.size}
+cond_dim: ${.ch}
+length: ${data.length}
\ No newline at end of file
diff --git a/configs/model/unet_campbell.yaml b/configs/model/unet_campbell.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4ead66e787299cc8c928449863667d951116ae94
--- /dev/null
+++ b/configs/model/unet_campbell.yaml
@@ -0,0 +1,19 @@
+name: unet
+type: unet
+ch: 128
+num_res_blocks: 2
+num_scales: 4
+ch_mult: [1, 2, 2, 2]
+input_channels: 3
+output_channels: -1  # determined by input_channels * 2
+scale_count_to_put_attn: 1  # attention placed at resolution 16
+data_min_max: [0, 255]  # not currently used; determined by [0, vocab_size]
+dropout: 0.1
+skip_rescale: True
+time_conditioning: True  # whether to add in time embeddings
+time_scale_factor: 1000
+time_embed_dim: ${.ch}
+fix_logistic: False
+size: ${data.size}
+cond_dim: ${.ch}
+length: ${data.length}
diff --git a/configs/noise/ar.yaml b/configs/noise/ar.yaml
new file mode 100644
index
0000000000000000000000000000000000000000..ea87023b328f9d078d5b4d4715b3fbfe270d1db7 --- /dev/null +++ b/configs/noise/ar.yaml @@ -0,0 +1,2 @@ +type: ar +scale: 6.0 \ No newline at end of file diff --git a/configs/noise/linear.yaml b/configs/noise/linear.yaml new file mode 100644 index 0000000000000000000000000000000000000000..12fffce86b73e79ad8bffe6fca4d3a9a7d8a808a --- /dev/null +++ b/configs/noise/linear.yaml @@ -0,0 +1,3 @@ +type: linear +sigma_min: 1e-3 +sigma_max: 7.0 \ No newline at end of file diff --git a/configs/noise/loglinear.yaml b/configs/noise/loglinear.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04c914f32ffe43e52b35ea31fb4617d366983393 --- /dev/null +++ b/configs/noise/loglinear.yaml @@ -0,0 +1,3 @@ +type: loglinear +sigma_min: 1e-4 +sigma_max: 20 \ No newline at end of file diff --git a/configs/noise/polynomial.yaml b/configs/noise/polynomial.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7218191733ac458ac2d9de1411fe5de7cde42a3c --- /dev/null +++ b/configs/noise/polynomial.yaml @@ -0,0 +1,5 @@ +type: polynomial +a: -3 +b: 5 +c: -4 +eps: 1e-3 \ No newline at end of file diff --git a/configs/strategy/ddp.yaml b/configs/strategy/ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d027eee4d9e2bc923154b045c53acb4c31565470 --- /dev/null +++ b/configs/strategy/ddp.yaml @@ -0,0 +1,2 @@ +_target_: lightning.pytorch.strategies.DDPStrategy +find_unused_parameters: false diff --git a/configs/strategy/fsdp.yaml b/configs/strategy/fsdp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ce2d35df6f1c1221617006bc565844bb10c77f5 --- /dev/null +++ b/configs/strategy/fsdp.yaml @@ -0,0 +1,3 @@ +# TODO(yair): Currently not compatible with grad clipping +_target_: lightning.pytorch.strategies.FSDPStrategy +sharding_strategy: SHARD_GRAD_OP diff --git a/guidance_eval/__init__.py b/guidance_eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/guidance_eval/amazon_polarity_eval.py b/guidance_eval/amazon_polarity_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..4939adc4ee7218f775f466a9cb2490b09bac80ba --- /dev/null +++ b/guidance_eval/amazon_polarity_eval.py @@ -0,0 +1,228 @@ +import collections +import json +import os + +import hydra +import lightning as L +import omegaconf +import pandas as pd +import rdkit +import rich.syntax +import rich.tree +import spacy +import torch +import transformers +# from evaluate import load +from nltk.util import ngrams +from tqdm.auto import tqdm + +import dataloader +import diffusion +import eval_utils + +rdkit.rdBase.DisableLog('rdApp.error') + +omegaconf.OmegaConf.register_new_resolver( + 'cwd', os.getcwd) +omegaconf.OmegaConf.register_new_resolver( + 'device_count', torch.cuda.device_count) +omegaconf.OmegaConf.register_new_resolver( + 'eval', eval) +omegaconf.OmegaConf.register_new_resolver( + 'div_up', lambda x, y: (x + y - 1) // y) +omegaconf.OmegaConf.register_new_resolver( + 'if_then_else', + lambda condition, x, y: x if condition else y +) + + +def _print_config( + config: omegaconf.DictConfig, + resolve: bool = True) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + config (DictConfig): Configuration composed by Hydra. + resolve (bool): Whether to resolve reference fields of DictConfig. 
+ """ + + style = 'dim' + tree = rich.tree.Tree('CONFIG', style=style, + guide_style=style) + + fields = config.keys() + for field in fields: + branch = tree.add(field, style=style, guide_style=style) + + config_section = config.get(field) + branch_content = str(config_section) + if isinstance(config_section, omegaconf.DictConfig): + branch_content = omegaconf.OmegaConf.to_yaml( + config_section, resolve=resolve) + + branch.add(rich.syntax.Syntax(branch_content, 'yaml')) + rich.print(tree) + +def compute_diversity(sentences): + # compute diversity + ngram_range = [2, 3, 4] + + tokenizer = spacy.load("en_core_web_sm").tokenizer + token_list = [] + for sentence in sentences: + token_list.append( + [str(token) for token in tokenizer(sentence)]) + ngram_sets = {} + ngram_counts = collections.defaultdict(int) + n_gram_repetition = {} + + for n in ngram_range: + ngram_sets[n] = set() + for tokens in token_list: + ngram_sets[n].update(ngrams(tokens, n)) + ngram_counts[n] += len(list(ngrams(tokens, n))) + n_gram_repetition[f"{n}gram_repetition"] = ( + 1 - len(ngram_sets[n]) / ngram_counts[n]) + diversity = 1 + for val in n_gram_repetition.values(): + diversity *= (1 - val) + return diversity + + +def compute_sentiment_classifier_score(sentences, eval_model_name_or_path): + tokenizer = transformers.AutoTokenizer.from_pretrained(eval_model_name_or_path) + eval_model = transformers.AutoModelForSequenceClassification.from_pretrained( + eval_model_name_or_path).to('cuda') + eval_model.eval() + + total_pos = 0 + total_neg = 0 + pbar = tqdm(sentences, desc='Classifier eval') + for sen in pbar: + # Tokenize the input text + inputs = tokenizer( + sen, + return_tensors="pt", + truncation=True, + padding=True).to('cuda') + + # Get the model predictions + with torch.no_grad(): + outputs = eval_model(**inputs) + + # Convert logits to probabilities + probs = torch.nn.functional.softmax( + outputs.logits, dim=-1) + + # Get the predicted class + predicted_class = torch.argmax(probs, dim=1).item() + if predicted_class == 1: + total_pos += 1 + else: + total_neg += 1 + pbar.set_postfix(accuracy=total_pos / (total_pos + total_neg)) + return total_pos / (total_pos + total_neg) + + +# def compute_mauve(config, tokenizer, sentences): +# os.environ["TOKENIZERS_PARALLELISM"] = "false" +# # compute mauve +# torch.cuda.empty_cache() +# mauve = load("mauve") +# human_references = [] +# +# valid_loader = dataloader.get_dataloaders( +# config, tokenizer, valid_seed=config.seed) +# +# # construct reference +# for batch_id in range(config.sampling.num_sample_batches): +# batch = next(iter(valid_loader)) +# input_ids = batch['input_ids'] +# for i in range(config.sampling.batch_size): +# idx = ( +# input_ids[i] == tokenizer.eos_token_id).nonzero( +# as_tuple=True) +# if idx[0].numel() > 0: +# idx = idx[0][0].item() +# input_ids[i, (idx + 1):] = 0 +# human_references.extend( +# tokenizer.batch_decode( +# input_ids, skip_special_tokens=True)) +# +# assert len(sentences) == len(human_references) +# +# results = mauve.compute(predictions=sentences, +# references=human_references, +# featurize_model_name=config.data.mauve_model, +# max_text_length=256, device_id=0) +# return results.mauve + + + +@hydra.main(version_base=None, config_path='../configs', + config_name='config') +def main(config: omegaconf.DictConfig) -> None: + # Reproducibility + L.seed_everything(config.seed) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False + + _print_config(config, 
resolve=True) + print(f"Checkpoint: {config.eval.checkpoint_path}") + + tokenizer = dataloader.get_tokenizer(config) + pretrained = diffusion.Diffusion.load_from_checkpoint( + config.eval.checkpoint_path, + tokenizer=tokenizer, + config=config, logger=False) + pretrained.eval() + result_dicts = [] + samples = [] + for _ in tqdm( + range(config.sampling.num_sample_batches), + desc='Gen. batches', leave=False): + sample = pretrained.sample() + samples.extend( + pretrained.tokenizer.batch_decode(sample)) + samples = [ + s.replace('[CLS]', '').replace('[SEP]', '').replace('[PAD]', '').replace('[MASK]', '').strip() + for s in samples + ] + del pretrained # free up space for eval + + diversity_score = compute_diversity(samples) + classifier_accuracy = compute_sentiment_classifier_score( + samples, eval_model_name_or_path=config.eval.classifier_model_name_or_path) + + generative_ppl = eval_utils.compute_generative_ppl( + samples, + eval_model_name_or_path=config.eval.generative_ppl_model_name_or_path, + gen_ppl_eval_batch_size=8, + max_length=config.model.length) + + result_dicts.append({ + 'Seed': config.seed, + 'T': config.sampling.steps, + 'Num Samples': config.sampling.batch_size * config.sampling.num_sample_batches, + 'Diversity': diversity_score, + 'Accuracy': classifier_accuracy, + 'Gen. PPL': generative_ppl, + } | {k.capitalize(): v for k, v in config.guidance.items()}) + print("Guidance:", ", ".join([f"{k.capitalize()} - {v}" for k, v in config.guidance.items()])) + print(f"\tDiversity: {diversity_score:0.3f} ", + f"Accuracy: {classifier_accuracy:0.3f} ", + f"Gen. PPL: {generative_ppl:0.3f}") + print(f"Generated {len(samples)} sentences.") + with open(config.eval.generated_samples_path, 'w') as f: + json.dump( + { + 'generated_seqs': samples, + }, + f, indent=4) # type: ignore + results_df = pd.DataFrame.from_records(result_dicts) + results_df.to_csv(config.eval.results_csv_path) + + +if __name__ == '__main__': + main() diff --git a/guidance_eval/qm9_eval.py b/guidance_eval/qm9_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..3a4e0e2a694246dbee51bbb0491a4e4e16d39cfb --- /dev/null +++ b/guidance_eval/qm9_eval.py @@ -0,0 +1,208 @@ +import json +import os +import time +import typing + +import datasets +import hydra +import lightning as L +import numpy as np +import omegaconf +import pandas as pd +import rdkit +import rich.syntax +import rich.tree +import torch +from rdkit import Chem as rdChem +from rdkit.Chem import QED +from tqdm.auto import tqdm + +import dataloader +import diffusion + +rdkit.rdBase.DisableLog('rdApp.error') + +omegaconf.OmegaConf.register_new_resolver( + 'cwd', os.getcwd) +omegaconf.OmegaConf.register_new_resolver( + 'device_count', torch.cuda.device_count) +omegaconf.OmegaConf.register_new_resolver( + 'eval', eval) +omegaconf.OmegaConf.register_new_resolver( + 'div_up', lambda x, y: (x + y - 1) // y) +omegaconf.OmegaConf.register_new_resolver( + 'if_then_else', + lambda condition, x, y: x if condition else y +) + + +def _print_config( + config: omegaconf.DictConfig, + resolve: bool = True) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + config (DictConfig): Configuration composed by Hydra. + resolve (bool): Whether to resolve reference fields of DictConfig. 
+ """ + + style = 'dim' + tree = rich.tree.Tree('CONFIG', style=style, + guide_style=style) + + fields = config.keys() + for field in fields: + branch = tree.add(field, style=style, guide_style=style) + + config_section = config.get(field) + branch_content = str(config_section) + if isinstance(config_section, omegaconf.DictConfig): + branch_content = omegaconf.OmegaConf.to_yaml( + config_section, resolve=resolve) + + branch.add(rich.syntax.Syntax(branch_content, 'yaml')) + rich.print(tree) + + +def get_mol_property_fn( + prop: str +) -> typing.Callable[[rdChem.Mol], typing.Union[int, float]]: + if prop == 'qed': + return QED.qed + if prop == 'ring_count': + return lambda x_mol: len(rdChem.GetSymmSSSR(x_mol)) + raise NotImplementedError( + f"Property function for {prop} not implemented") + + +@hydra.main(version_base=None, config_path='../configs', + config_name='config') +def main(config: omegaconf.DictConfig) -> None: + # Reproducibility + L.seed_everything(config.seed) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False + + _print_config(config, resolve=True) + print(f"Checkpoint: {config.eval.checkpoint_path}") + + qm9_dataset = datasets.load_dataset( + 'yairschiff/qm9', trust_remote_code=True, + split='train') + tokenizer = dataloader.get_tokenizer(config) + pretrained = diffusion.Diffusion.load_from_checkpoint( + config.eval.checkpoint_path, + tokenizer=tokenizer, + config=config, logger=False) + pretrained.eval() + label_col = config.data.label_col + pctile_threshold = config.data.label_col_pctile + pctile_threshold_value = np.percentile( + qm9_dataset[label_col], q=pctile_threshold) + above_threshold = np.array(qm9_dataset[label_col])[ + qm9_dataset[label_col] >= pctile_threshold_value] + below_threshold = np.array(qm9_dataset[label_col])[ + qm9_dataset[label_col] < pctile_threshold_value] + result_dicts = [] + mol_property_fn = get_mol_property_fn(label_col) + + print( + f"All - {label_col.upper()} Mean: {np.mean(qm9_dataset[label_col]):0.3f}, {label_col.upper()} Median: {np.median(qm9_dataset[label_col]):0.3f}") + print( + f"Below {pctile_threshold}%ile - {label_col.upper()} Mean: {np.mean(below_threshold):0.3f}, {label_col.upper()} Median: {np.median(below_threshold):0.3f}") + print( + f"Above {pctile_threshold}%ile - {label_col.upper()} Mean: {np.mean(above_threshold):0.3f}, {label_col.upper()} Median: {np.median(above_threshold):0.3f}") + result_dicts.append({ + 'Seed': -1, + 'T': -1, + 'Num Samples': len(qm9_dataset), + 'Valid': 1.0, + 'Unique': 1.0, + 'Novel': 1.0, + f'{label_col.upper()} Mean': np.mean(qm9_dataset[label_col]), + f'{label_col.upper()} 25%ile': np.percentile(qm9_dataset[label_col], q=25), + f'{label_col.upper()} Median': np.median(qm9_dataset[label_col]), + f'{label_col.upper()} 75%ile': np.percentile(qm9_dataset[label_col], q=75), + f'Novel {label_col.upper()} Mean': np.mean(qm9_dataset[label_col]), + f'Novel {label_col.upper()} 25%ile': np.percentile(qm9_dataset[label_col], q=25), + f'Novel {label_col.upper()} Median': np.median(qm9_dataset[label_col]), + f'Novel {label_col.upper()} 75%ile': np.percentile(qm9_dataset[label_col], q=75), + } | {k.capitalize(): -1 for k, v in config.guidance.items()}) + + samples = [] + for _ in tqdm( + range(config.sampling.num_sample_batches), + desc='Gen. 
+  invalids = []
+  valids = []
+  mol_property = []
+  for t in samples:
+    # Strip special tokens before parsing; the token strings below are assumed
+    # to match the QM9 tokenizer's BOS/EOS/PAD tokens.
+    t = t.replace('<bos>', '').replace('<eos>', '').replace('<pad>', '')
+    try:
+      mol = rdChem.MolFromSmiles(t)
+      if mol is None or len(t) == 0:
+        invalids.append(t)
+      else:
+        valids.append(t)
+        mol_property.append(mol_property_fn(mol))
+    except rdkit.Chem.rdchem.KekulizeException as e:
+      print(e)
+      invalids.append(t)
+  valid = len(valids)
+  valid_pct = len(valids) / len(samples)
+  unique = len(set(valids))
+  novel = len(set(valids) - set(qm9_dataset['canonical_smiles']))
+  try:
+    unique_pct = unique / valid
+    novel_pct = novel / valid
+  except ZeroDivisionError:
+    unique_pct, novel_pct = 0., 0.
+  mol_property_novel = [
+    mol_property_fn(rdChem.MolFromSmiles(s))
+    for s in set(valids) - set(qm9_dataset['canonical_smiles'])
+  ]
+  result_dicts.append({
+    'Seed': config.seed,
+    'T': config.sampling.steps,
+    'Num Samples': config.sampling.batch_size * config.sampling.num_sample_batches,
+    'Valid': valid_pct,
+    'Unique': unique_pct,
+    'Novel': novel_pct,
+    f'{label_col.upper()} Mean': np.mean(mol_property) if len(mol_property) > 0 else 0.,
+    f'{label_col.upper()} 25%ile': np.percentile(mol_property, q=25) if len(mol_property) > 0 else 0.,
+    f'{label_col.upper()} Median': np.median(mol_property) if len(mol_property) > 0 else 0.,
+    f'{label_col.upper()} 75%ile': np.percentile(mol_property, q=75) if len(mol_property) > 0 else 0.,
+    f'Novel {label_col.upper()} Mean': np.mean(mol_property_novel) if len(mol_property_novel) > 0 else 0.,
+    f'Novel {label_col.upper()} 25%ile': np.percentile(mol_property_novel, q=25) if len(mol_property_novel) > 0 else 0.,
+    f'Novel {label_col.upper()} Median': np.median(mol_property_novel) if len(mol_property_novel) > 0 else 0.,
+    f'Novel {label_col.upper()} 75%ile': np.percentile(mol_property_novel, q=75) if len(mol_property_novel) > 0 else 0.,
+  } | {k.capitalize(): v for k, v in config.guidance.items()})
+  print("Guidance:", ", ".join([f"{k.capitalize()} - {v}" for k, v in config.guidance.items()]))
+  print(f"\tValid: {valid:,d} / {len(samples):,d} ({100 * valid_pct:0.2f}%) ",
+        f"Unique (of valid): {unique:,d} / {valid:,d} ({100 * unique_pct:0.2f}%) ",
+        f"Novel (of valid): {novel:,d} / {valid:,d} ({100 * novel_pct:0.2f}%)\n",
+        f"\t{label_col.upper()} Mean: {np.mean(mol_property) if len(mol_property) else 0.:0.3f}, {label_col.upper()} Median: {np.median(mol_property) if len(mol_property) else 0.:0.3f}\n",
+        f"\tNovel {label_col.upper()} Mean: {np.mean(mol_property_novel) if len(mol_property_novel) else 0.:0.3f}, Novel {label_col.upper()} Median: {np.median(mol_property_novel) if len(mol_property_novel) else 0.:0.3f}"
+        )
+  print(f"Generated {len(samples)} molecules.")
+  with open(config.eval.generated_samples_path, 'w') as f:
+    json.dump(
+      {
+        'valid': valids,
+        'novel': list(set(valids) - set(qm9_dataset['canonical_smiles'])),
+        f"{label_col}_valid": mol_property,
+        f"{label_col}_novel": mol_property_novel,
+      },
+      f, indent=4)  # type: ignore
+  results_df = pd.DataFrame.from_records(result_dicts)
+  results_df.to_csv(config.eval.results_csv_path)
+
+
+if __name__ == '__main__':
+  main()
diff --git a/guidance_eval/ten_species_eval.py b/guidance_eval/ten_species_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e50430be238612a4503ac1efaa9a73a738f3f356
--- /dev/null
+++
b/guidance_eval/ten_species_eval.py @@ -0,0 +1,585 @@ +import itertools +import json +import os +import typing + +import datasets +import hydra +import lightning as L +import numpy as np +import omegaconf +import pandas as pd +import rdkit +import rich.syntax +import rich.tree +import scipy +import torch +import transformers +from sklearn.metrics import ( + f1_score, + matthews_corrcoef, + precision_score, + recall_score, + roc_auc_score +) +from tqdm.auto import tqdm + +import classifier +import custom_datasets +import dataloader +import diffusion + +rdkit.rdBase.DisableLog('rdApp.error') + +omegaconf.OmegaConf.register_new_resolver( + 'cwd', os.getcwd) +omegaconf.OmegaConf.register_new_resolver( + 'device_count', torch.cuda.device_count) +omegaconf.OmegaConf.register_new_resolver( + 'eval', eval) +omegaconf.OmegaConf.register_new_resolver( + 'div_up', lambda x, y: (x + y - 1) // y) +omegaconf.OmegaConf.register_new_resolver( + 'if_then_else', + lambda condition, x, y: x if condition else y +) + + +def _print_config( + config: omegaconf.DictConfig, + resolve: bool = True) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + config (DictConfig): Configuration composed by Hydra. + resolve (bool): Whether to resolve reference fields of DictConfig. + """ + + style = 'dim' + tree = rich.tree.Tree('CONFIG', style=style, + guide_style=style) + + fields = config.keys() + for field in fields: + branch = tree.add(field, style=style, guide_style=style) + + config_section = config.get(field) + branch_content = str(config_section) + if isinstance(config_section, omegaconf.DictConfig): + branch_content = omegaconf.OmegaConf.to_yaml( + config_section, resolve=resolve) + + branch.add(rich.syntax.Syntax(branch_content, 'yaml')) + rich.print(tree) + + +def generate_ordered_kmers( + kmer_length: int +) -> typing.List[str]: + """ + Function that generates all kmers of a given length and orders them by their index + defined by the kmer_to_index function. + + Args: + kmer_length (int): The length of the kmers to generate + + Returns: + List[str]: A list of all kmers of the given length ordered by their index + """ + characters = ["A", "C", "G", "T"] + + kmers = ["".join(kmer) for kmer in + itertools.product(characters, + repeat=kmer_length)] + ordered_kmers = sorted(kmers, key=kmer_to_index) + + return ordered_kmers + + +def kmer_to_index(kmer: str) -> int: + """ + Function that converts a given kmer to a unique value + system. + + Args: + kmer (str): The given kmer + + Returns: + int: The associated unique value + + Example: + >>> kmer_to_index("AAC") + 1 + + """ + mapping = {"A": 0, "C": 1, "G": 2, "T": 3} + index = 0 + for char in kmer: + index = index * 4 + mapping[char] + return index + + +def compute_kmer_frequencies( + seqs: typing.List[str], kmer_length: int +) -> typing.Tuple[typing.List[float], typing.List[str]]: + """ + Computes the kmer frequencies in a list of sequences. + Each element of the output array is the frequency of a given kmer over the whole + set of sequences. 
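+  Counts are pooled over all input sequences and normalized by the total
+  number of kmer occurrences, so the returned frequencies sum to at most 1.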
+
+  Args:
+    seqs (List[str]): List of nucleotide sequences
+    kmer_length (int): Length of the kmers
+
+  Returns:
+    List[float]: Kmer frequencies
+    List[str]: The kmers
+
+  Example:
+    >>> seqs = ["AGCT", "AAAA"]
+    >>> compute_kmer_frequencies(seqs, kmer_length=1)
+    ([0.625, 0.125, 0.125, 0.125], ['A', 'C', 'G', 'T'])
+  """
+
+  kmer_counts: typing.Dict[str, int] = {}
+  count_kmers_occurrences = 0
+  for seq in seqs:
+    for i in range(len(seq) - kmer_length + 1):
+      kmer = seq[i: i + kmer_length]
+      if kmer in kmer_counts:
+        kmer_counts[kmer] += 1
+      else:
+        kmer_counts[kmer] = 1
+      count_kmers_occurrences += 1
+
+  kmer_list = generate_ordered_kmers(kmer_length)
+  kmer_frequencies = []
+  for kmer in kmer_list:
+    try:
+      kmer_frequencies.append(
+        kmer_counts[kmer] / count_kmers_occurrences)
+    except KeyError:
+      kmer_frequencies.append(0)
+
+  return kmer_frequencies, kmer_list
+
+
+def run_eval_pipeline(
+    seqs: typing.Dict[int, typing.List[str]],
+    num_samples_per_class: int,
+    train_weights_path: str,
+    val_weights_path: str,
+    eval_classifier_checkpoint_path: str,
+    kmer_freqs_path: str
+):
+  # Eval pipeline
+  L.seed_everything(42)
+
+  # Load classifier
+  with hydra.initialize(version_base=None,
+                        config_path='../configs/'):
+    classifier_config = hydra.compose(
+      config_name='config',
+      overrides=[
+        'hydra.output_subdir=null',
+        'hydra.job.chdir=False',
+        'hydra/job_logging=disabled',
+        'hydra/hydra_logging=disabled',
+        '+is_eval_classifier=True',
+        'mode=train_classifier',
+        'loader.global_batch_size=32',
+        'loader.eval_global_batch_size=64',
+        'loader.batch_size=2',
+        'loader.eval_batch_size=4',
+        'data=ten_species',
+        'classifier_model=hyenadna-classifier',
+        'classifier_model.hyena_model_name_or_path=LongSafari/hyenadna-small-32k-seqlen-hf',
+        'classifier_backbone=hyenadna',
+        'classifier_model.n_layer=8',
+        'model.length=32768',
+        'diffusion=null',
+        'T=null',
+        f"eval.checkpoint_path={eval_classifier_checkpoint_path}"
+      ]
+    )
+    classifier_config = omegaconf.OmegaConf.create(
+      classifier_config)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(
+    classifier_config.data.tokenizer_name_or_path,
+    trust_remote_code=True)
+  pretrained_classifier = classifier.Classifier.load_from_checkpoint(
+    classifier_config.eval.checkpoint_path,
+    tokenizer=tokenizer,
+    config=classifier_config, logger=False)
+  pretrained_classifier.eval()
+
+  tokenizer = dataloader.get_tokenizer(classifier_config)
+  _, val_dl = dataloader.get_dataloaders(
+    classifier_config, tokenizer, skip_train=True,
+    valid_seed=classifier_config.seed)
+
+  dataset = datasets.load_dataset(
+    'yairschiff/ten_species',
+    split='train',
+    # original dataset only has `train` split
+    chunk_length=classifier_config.model.length,
+    overlap=0,
+    trust_remote_code=True)
+  dataset = dataset.train_test_split(
+    test_size=0.05, seed=42)
+  train_dataset = dataset['train']
+  val_dataset = dataset['test']
+
+
+  print(f"Len of train set {len(train_dataset) * (2 ** 15):,d}")
+  print(f"Len of val set {len(val_dataset) * (2 ** 15):,d}")
+
+  int_to_species = ['Homo_sapiens', 'Mus_musculus',
+                    'Drosophila_melanogaster',
+                    'Danio_rerio',
+                    'Caenorhabditis_elegans',
+                    'Gallus_gallus', 'Gorilla_gorilla',
+                    'Felis_catus',
+                    'Salmo_trutta', 'Arabidopsis_thaliana']
+
+  if os.path.exists(train_weights_path):
+    train_weights = torch.load(train_weights_path)
+  else:
+    train_weights = {k: 0 for k in range(10)}
+    for i in tqdm(train_dataset, leave=False):
+      train_weights[i['species_label']] += 1
+    train_weights = {
+      k: v /
np.sum(list(train_weights.values())) for k, v + in train_weights.items()} + torch.save(train_weights, train_weights_path) + print('Train weights:') + for k, v in train_weights.items(): + print("\t", int_to_species[k], f"{100 * v:0.2f}") + + if os.path.exists(val_weights_path): + val_weights = torch.load(val_weights_path) + else: + val_weights = {k: 0 for k in range(10)} + for i in tqdm(val_dataset, leave=False): + val_weights[i['species_label']] += 1 + val_weights = {k: v / np.sum(list(val_weights.values())) + for k, v in val_weights.items()} + torch.save(val_weights, val_weights_path) + print('\nVal weights:') + for k, v in val_weights.items(): + print("\t", int_to_species[k], f"{100 * v:0.2f}") + + + result_dict = {} + test_data = [] + + for k, v in seqs.items(): + test_data.extend( + [ + { + 'sequence': s.replace('[CLS]', '').replace( + '[BOS]', '').replace('[MASK]', '').replace( + '[SEP]', '').replace('[PAD]', '').replace( + '[UNK]', ''), + 'species_label': k + } + for s in v + ] + ) + test_dataset = custom_datasets.ten_species_dataset.TenSpeciesDataset( + split='test', + tokenizer=tokenizer, + max_length=classifier_config.model.length, + rc_aug=False, + add_special_tokens=classifier_config.data.add_special_tokens, + dataset=test_data + ) + + ## CLASSIFIER ACCURACY + test_preds = [ + pretrained_classifier.forward( + test_dataset[i]['input_ids'][None, ...].to( + 'cuda')).argmax(dim=-1).detach().item() + for i in + tqdm(range(len(test_dataset)), desc='Testing') + ] + test_preds = np.array(test_preds) + + test_labels = [] + for k, v in seqs.items(): + test_labels.extend([int(k)] * len(v)) + test_labels = np.array(test_labels) + + overall_accuracy_score = (test_preds == test_labels).sum() / test_preds.size + overall_f1_score = f1_score(y_pred=test_preds, + y_true=test_labels, + average="macro", + labels=list(range(classifier_config.data.num_classes))) + overall_mcc_score = matthews_corrcoef(y_pred=test_preds, y_true=test_labels) + + print(f"Overall Acc: {overall_accuracy_score:0.2f}") + print(f"Overall F1: {overall_f1_score:0.2f}") + print(f"Overall MCC: {overall_mcc_score:0.2f}") + result_dict['F1'] = overall_f1_score + + f1_scores = f1_score( + y_pred=test_preds, + y_true=test_labels, + average=None, + labels=list(range(classifier_config.data.num_classes))) + precision_scores = precision_score( + y_pred=test_preds, + y_true=test_labels, + average=None, + labels=list(range(classifier_config.data.num_classes))) + recall_scores = recall_score( + y_pred=test_preds, + y_true=test_labels, + average=None, + labels=list(range(classifier_config.data.num_classes))) + + species_list = ['Homo_sapiens', 'Mus_musculus', + 'Drosophila_melanogaster', + 'Danio_rerio', + 'Caenorhabditis_elegans', + 'Gallus_gallus', 'Gorilla_gorilla', + 'Felis_catus', + 'Salmo_trutta', + 'Arabidopsis_thaliana'] + for s in range(classifier_config.data.num_classes): + print(f"Class {s} - {species_list[s]}:") + print(f" F1: {f1_scores[s]:0.3f}") + print(f" Precision: {precision_scores[s]:0.3f}") + print(f" Recall: {recall_scores[s]:0.3f}") + + ## KMER SPECTRUM + kmer_lengths = [3, 6] + kmer_results = {k: [] for k in kmer_lengths} + if os.path.exists(kmer_freqs_path): + kmer_freqs = torch.load(kmer_freqs_path) + else: + kmer_freqs = {s: { + kmer_length: {'frequencies': None, + 'kmers': None} for kmer_length in + kmer_lengths} for s in range(10)} + for s in range(10): + filter_ds = val_dataset.filter( + lambda x: x['species_label'] == s, + num_proc=len(os.sched_getaffinity(0))) + print(f"Computing kmer frequencies for species 
class {s}") + for kmer_length in kmer_lengths: + kmer_frequencies_gt, kmer_list = compute_kmer_frequencies( + seqs=filter_ds['sequence'], + kmer_length=kmer_length + ) + kmer_freqs[s][kmer_length]['frequencies'] = kmer_frequencies_gt + kmer_freqs[s][kmer_length]['kmers'] = kmer_list + torch.save(kmer_freqs, kmer_freqs_path) + for s in range(10): + print(f"Species class {s}") + mean_js_divergence = 0 + for kmer_length in kmer_lengths: + kmer_frequencies_gt = kmer_freqs[s][kmer_length]['frequencies'] + kmer_frequencies_generated, kmer_list = compute_kmer_frequencies( + seqs=[i['sequence'] for i in test_data if + i['species_label'] == s], + kmer_length=kmer_length + ) + + js_divergence = np.sum( + scipy.spatial.distance.jensenshannon( + kmer_frequencies_gt, + kmer_frequencies_generated) + ) + kmer_results[kmer_length].append(js_divergence) + mean_js_divergence += js_divergence + print( + f"\tJS divergence with k={kmer_length} : {js_divergence}") + print( + f"\tMean JS divergence : {mean_js_divergence / len(kmer_lengths):0.2f}") + + for k, v in kmer_results.items(): + weighted_kmer_js = (np.array(v) * np.array( + list(val_weights.values()))).sum() + print( + f"Weighted mean JS divergence across classes with k={k}: {weighted_kmer_js:0.2f}") + result_dict[f"{k}mer JS"] = weighted_kmer_js + + ## DISCRIMINATOR AUROC + # Hyperparams + d_model = 128 + n_layer = 2 + + batch_size = 8 + lr = 1e-4 + epochs = 5 + + disc_data = [ + {'sequence': i['sequence'], 'species_label': 0} + for i in test_data] + for s in range(10): + filter_val_ds = val_dataset.filter( + lambda x: x['species_label'] == s, + num_proc=len(os.sched_getaffinity(0))) + indices = np.random.permutation( + np.arange(len(filter_val_ds)))[:num_samples_per_class] + disc_data.extend( + [{'sequence': i['sequence'], 'species_label': 1} + for i in filter_val_ds.select(indices)] + ) + print(f"Size of discriminator dataset: {len(disc_data)}") + disc_dataset_hf = datasets.Dataset.from_list( + disc_data) + disc_dataset_hf = disc_dataset_hf.train_test_split( + test_size=0.1, seed=42) + + disc_dataset_train = custom_datasets.ten_species_dataset.TenSpeciesDataset( + split='train', + tokenizer=tokenizer, + max_length=classifier_config.model.length, + rc_aug=False, + add_special_tokens=classifier_config.data.add_special_tokens, + dataset=disc_dataset_hf['train'] + ) + + disc_dataset_val = custom_datasets.ten_species_dataset.TenSpeciesDataset( + split='test', + tokenizer=tokenizer, + max_length=classifier_config.model.length, + rc_aug=False, + add_special_tokens=classifier_config.data.add_special_tokens, + dataset=disc_dataset_hf['test'] + ) + + disc_train_dl = torch.utils.data.DataLoader( + disc_dataset_train, + batch_size=batch_size, + num_workers=0, + pin_memory=True, + shuffle=True) + + disc_val_dl = torch.utils.data.DataLoader( + disc_dataset_val, + batch_size=batch_size, + num_workers=0, + pin_memory=True, + shuffle=False) + + hyena_config = transformers.AutoConfig.from_pretrained( + 'LongSafari/hyenadna-small-32k-seqlen-hf', + d_model=d_model, + n_layer=n_layer, + trust_remote_code=True) + disc_model = transformers.AutoModelForSequenceClassification.from_config( + hyena_config, + pretrained=False, + num_labels=2, + problem_type='single_label_classification', + trust_remote_code=True) + + optimizer = torch.optim.AdamW( + disc_model.parameters(), lr=lr, weight_decay=0, + betas=(0.9, 0.999), eps=1e-8) + + disc_model.to('cuda') + losses = [] + auroc_list = [] + for ep in tqdm(range(epochs), desc='Epochs'): + # Train loop: + disc_model.train() + 
train_pbar = tqdm(disc_train_dl, desc='Train', + leave=False) + for batch in train_pbar: + labels = batch['species_label'].to('cuda') + logits = disc_model( + batch['input_ids'].to('cuda')).logits + loss = torch.nn.functional.cross_entropy( + logits.view(-1, logits.size(-1)), + labels, + ignore_index=-100, + reduction='mean') + optimizer.zero_grad() + loss.backward() + optimizer.step() + train_pbar.set_postfix({'loss': loss.item()}) + losses.append(loss.item()) + # Val loop: + disc_model.eval() + disc_labels = [] + disc_preds = [] + for batch in disc_val_dl: + disc_labels.append( + batch['species_label'].numpy()) + disc_preds.append( + disc_model( + batch['input_ids'].to('cuda') + ).logits[..., 1].detach().to('cpu').numpy() + ) + disc_labels = np.concatenate(disc_labels) + disc_preds = np.concatenate(disc_preds) + auroc = roc_auc_score(y_true=disc_labels, y_score=disc_preds) + auroc_list.append(auroc) + print(f"Ep {ep} - AUROC score {auroc}") + result_dict["Disc AUROC"] = auroc_list[-1] + del disc_model + print('*****************************') + return result_dict + + +@hydra.main(version_base=None, config_path='../configs', + config_name='config') +def main(config: omegaconf.DictConfig) -> None: + # Reproducibility + L.seed_everything(config.seed) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False + + _print_config(config, resolve=True) + print(f"Checkpoint: {config.eval.checkpoint_path}") + + tokenizer = dataloader.get_tokenizer(config) + pretrained = diffusion.Diffusion.load_from_checkpoint( + config.eval.checkpoint_path, + tokenizer=tokenizer, + config=config, logger=False) + pretrained.eval() + + # Generate samples + if not os.path.exists(config.eval.generated_samples_path): + samples_per_class = {} + classes = range(config.data.num_classes) + for species in classes: + config.guidance.condition = species + print("Guidance:", ", ".join([f"{k.capitalize()} - {v}" for k, v in config.guidance.items()])) + samples = [] + for _ in tqdm( + range(config.sampling.num_sample_batches), desc='Gen. 
batches', leave=False): + sample = pretrained.sample() + samples.extend(pretrained.tokenizer.batch_decode(sample)) + samples_per_class[species] = samples + with open(config.eval.generated_samples_path, 'w') as f: + json.dump(samples_per_class, f, indent=4) # type: ignore + else: + with open(config.eval.generated_samples_path, 'r') as f: + samples_per_class = json.load(f) + samples_per_class = {int(k): v for k, v in samples_per_class.items()} + + # Run eval pipeline + hydra.core.global_hydra.GlobalHydra.instance().clear() + result_dict = run_eval_pipeline( + samples_per_class, + num_samples_per_class=config.sampling.num_sample_batches*config.sampling.batch_size, + train_weights_path=config.eval.train_weights_path, + val_weights_path=config.eval.val_weights_path, + eval_classifier_checkpoint_path=config.eval.eval_classifier_checkpoint_path, + kmer_freqs_path=config.eval.kmer_freqs_path) + result_dict['Seed'] = config.seed + result_dict['T'] = config.sampling.steps + result_dict = result_dict | {k.capitalize(): v for k, v in config.guidance.items()} + result_dict['Num Samples'] = sum([len(v) for v in samples_per_class.values()]) + results_df = pd.DataFrame.from_records([result_dict]) + results_df.to_csv(config.eval.results_csv_path) + +if __name__ == '__main__': + main() diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..106385263fe5c225ad92a55d0aaf22a0f47d1bea --- /dev/null +++ b/main.py @@ -0,0 +1,262 @@ +import json +import os + +import fsspec +import hydra +import lightning as L +import omegaconf +import rich.syntax +import rich.tree +import torch +from tqdm import tqdm +from datasets import load_from_disk +import pdb + +import classifier +import dataloader +import diffusion +import eval_utils +import utils + +omegaconf.OmegaConf.register_new_resolver( + 'cwd', os.getcwd) +omegaconf.OmegaConf.register_new_resolver( + 'device_count', torch.cuda.device_count) +omegaconf.OmegaConf.register_new_resolver( + 'eval', eval) +omegaconf.OmegaConf.register_new_resolver( + 'div_up', lambda x, y: (x + y - 1) // y) +omegaconf.OmegaConf.register_new_resolver( + 'if_then_else', + lambda condition, x, y: x if condition else y +) + + +def _load_from_checkpoint(config, tokenizer): + if 'hf' in config.backbone: + return diffusion.Diffusion( + config, tokenizer=tokenizer).to('cuda') + + return diffusion.Diffusion.load_from_checkpoint( + config.eval.checkpoint_path, + tokenizer=tokenizer, + config=config, logger=False).to('cuda') + + +@L.pytorch.utilities.rank_zero_only +def _print_config( + config: omegaconf.DictConfig, + resolve: bool = True, + save_cfg: bool = True) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + config (DictConfig): Configuration composed by Hydra. + resolve (bool): Whether to resolve reference fields of DictConfig. + save_cfg (bool): Whether to save the configuration tree to a file. 
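+
+  Note: decorated with `rank_zero_only`, so the tree is printed and saved
+  once in multi-GPU runs.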
+ """ + + style = 'dim' + tree = rich.tree.Tree('CONFIG', style=style, guide_style=style) + + fields = config.keys() + for field in fields: + branch = tree.add(field, style=style, guide_style=style) + + config_section = config.get(field) + branch_content = str(config_section) + if isinstance(config_section, omegaconf.DictConfig): + branch_content = omegaconf.OmegaConf.to_yaml( + config_section, resolve=resolve) + + branch.add(rich.syntax.Syntax(branch_content, 'yaml')) + rich.print(tree) + if save_cfg: + with fsspec.open( + '{}/config_tree.txt'.format( + config.checkpointing.save_dir), 'w') as fp: + rich.print(tree, file=fp) + + +@L.pytorch.utilities.rank_zero_only +def _print_batch(train_ds, valid_ds, tokenizer, k=64): + for dl_type, dl in [ + ('train', train_ds), ('valid', valid_ds)]: + print(f'Printing {dl_type} dataloader batch.') + batch = next(iter(dl)) + print('Batch input_ids.shape', batch['input_ids'].shape) + first = batch['input_ids'][0, :k] + last = batch['input_ids'][0, -k:] + print(f'First {k} tokens:', tokenizer.decode(first)) + print('ids:', first) + print(f'Last {k} tokens:', tokenizer.decode(last)) + print('ids:', last) + + +def _train(config, logger, tokenizer, + train_classifier=False): + logger.info('Starting Training.') + wandb_logger = None + if config.get('wandb', None) is not None: + wandb_logger = L.pytorch.loggers.WandbLogger( + config=omegaconf.OmegaConf.to_object(config), + ** config.wandb) + + if (config.checkpointing.resume_from_ckpt + and config.checkpointing.resume_ckpt_path is not None + and utils.fsspec_exists( + config.checkpointing.resume_ckpt_path)): + ckpt_path = config.checkpointing.resume_ckpt_path + else: + ckpt_path = None + + # Lightning callbacks + callbacks = [] + if 'callbacks' in config: + for _, callback in config.callbacks.items(): + callbacks.append(hydra.utils.instantiate(callback)) + + # train_ds, valid_ds = dataloader.get_dataloaders( + # config, tokenizer) + train_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/train') + val_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/val') + test_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/test') + + data_module = dataloader.CustomDataModule(train_dataset, val_dataset, test_dataset, tokenizer, config, batch_size=config.loader.batch_size) + train_ds = data_module.train_dataloader() + valid_ds = data_module.val_dataloader() + + if not config.is_vision: + _print_batch(train_ds, valid_ds, tokenizer) + + if train_classifier: + # This param indicates classifier will be used for + # PPLM / NOS-style guidance + # (see: https://arxiv.org/abs/2305.20009). 
+    if getattr(config, 'is_pplm_classifier', False):
+      pretrained_model = _load_from_checkpoint(
+        config, tokenizer)
+      if (getattr(config.classifier_model, 'use_encoder_ema', True)
+          and pretrained_model.ema):
+        pretrained_model.load_ema_params()
+      pretrained_backbone = pretrained_model.backbone
+      # Remove the output head so the classifier can attach its own
+      if hasattr(pretrained_backbone, 'output_layer'):  # DiT
+        delattr(pretrained_backbone, 'output_layer')
+      # hasattr/delattr do not accept dotted names, so walk to the
+      # submodule explicitly
+      if hasattr(getattr(pretrained_backbone, 'model', None), 'lm_head'):  # DiMamba
+        delattr(pretrained_backbone.model, 'lm_head')
+      if getattr(config.classifier_model, 'freeze_encoder', True):
+        for param in pretrained_backbone.parameters():
+          param.requires_grad = False
+    else:
+      pretrained_backbone = None
+
+    model = classifier.Classifier(
+      config,
+      tokenizer=valid_ds.tokenizer,
+      pretrained_backbone=pretrained_backbone)
+  else:
+    model = diffusion.Diffusion(
+      config, tokenizer=tokenizer)
+
+  trainer = hydra.utils.instantiate(
+    config.trainer,
+    default_root_dir=os.getcwd(),
+    callbacks=callbacks,
+    strategy=hydra.utils.instantiate(config.strategy),
+    logger=wandb_logger)
+  trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)
+
+
+def _gen_ppl_eval(config, tokenizer):
+  pretrained = _load_from_checkpoint(
+    config=config, tokenizer=tokenizer)
+  pretrained.eval()
+  samples = []
+  for _ in tqdm(range(config.sampling.num_sample_batches),
+                desc='Gen. batches', leave=False):
+    sample = pretrained.sample()
+    samples.extend(
+      pretrained.tokenizer.batch_decode(sample))
+
+  # Replace CLS token with BOS token (if applicable) and
+  # remove padding and mask tokens
+  tok_bos_token = tokenizer.bos_token if tokenizer.bos_token is not None else tokenizer.cls_token
+  samples = [
+    s.replace('[PAD]', '').replace('[MASK]', '').strip()
+    for s in samples
+  ]
+  # Add BOS token to the beginning of each sample (if not already present)
+  samples = [
+    s if s.startswith(tok_bos_token) else f"{tok_bos_token} {s}"
+    for s in samples
+  ]
+  del pretrained  # free up space for eval
+  print(f"Generated {len(samples)} samples.")
+
+  generative_ppl = eval_utils.compute_generative_ppl(
+    samples,
+    eval_model_name_or_path=config.eval.generative_ppl_model_name_or_path,
+    gen_ppl_eval_batch_size=8,
+    max_length=config.model.length)
+  tokens = tokenizer.batch_encode_plus(
+    samples,
+    return_tensors='pt',
+    add_special_tokens=False,
+    max_length=config.model.length,
+    padding='max_length',
+    truncation=True)['input_ids']
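+  # Token-level diversity check: with token counts c and p = c / c.sum(),
+  # torch.special.entr(p) = -p * log(p), so the sum below is the entropy
+  # H = -sum_i p_i log p_i of the pooled sample token distribution.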
+  # `tokens` is already a LongTensor; re-wrapping it in torch.tensor()
+  # would copy it and raise a UserWarning
+  _, counts = torch.unique(
+    tokens, return_counts=True, sorted=False)
+  entropy = torch.special.entr(
+    counts.float() / counts.sum()).sum().item()
+  with open(config.eval.generated_samples_path, 'w') as f:
+    json.dump({
+      'generative_ppl': generative_ppl,
+      'entropy': entropy,
+      'generated_seqs': samples,
+    },
+      f, indent=4)  # type: ignore
+  print(f"Entropy: {entropy:0.3f}")
+  print(f"Gen. PPL: {generative_ppl:0.3f}")
+
+
+def _ppl_eval(config, tokenizer):
+  print(f"Evaluating perplexity on {config.data.valid}.")
+  pretrained = _load_from_checkpoint(
+    config=config, tokenizer=tokenizer)
+  pretrained.eval()
+  if not config.eval.disable_ema:
+    pretrained.load_ema_params()
+
+  _, valid_ds = dataloader.get_dataloaders(
+    config, tokenizer, skip_train=True, valid_seed=config.seed)
+  ppl = eval_utils.compute_ppl(pretrained, valid_ds)
+  print(f"PPL: {ppl:0.3f}")
+
+
+@hydra.main(version_base=None, config_path='configs',
+            config_name='config')
+def main(config):
+  """Main entry point for training."""
+  L.seed_everything(config.seed)
+  _print_config(config, resolve=True, save_cfg=True)
+
+  logger = utils.get_logger(__name__)
+  tokenizer = dataloader.get_tokenizer(config)
+
+  if config.mode == 'gen_ppl_eval':
+    _gen_ppl_eval(config, tokenizer)
+  elif config.mode == 'ppl_eval':
+    _ppl_eval(config, tokenizer)
+  elif 'train' in config.mode:
+    _train(config, logger, tokenizer,
+           train_classifier='classifier' in config.mode)
+  else:
+    raise NotImplementedError(f"Mode {config.mode} not implemented.")
+
+
+if __name__ == '__main__':
+  main()
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..63499c089932632f178ef42f0a4d23dc50640449
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,4 @@
+from . import dit
+from . import dimamba
+from . import ema
+from . import unet
\ No newline at end of file
diff --git a/models/bindevaluator.py b/models/bindevaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..15d816235d066441a4c883c48ab76ecbfee72d56
--- /dev/null
+++ b/models/bindevaluator.py
@@ -0,0 +1,78 @@
+import pdb
+import time
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pytorch_lightning as pl
+from transformers import EsmModel
+
+from .bindevaluator_modules import *
+
+
+class BindEvaluator(pl.LightningModule):
+    def __init__(self, n_layers, d_model, d_hidden, n_head,
+                 d_k, d_v, d_inner, dropout=0.2,
+                 learning_rate=0.00001, max_epochs=15, kl_weight=1):
+        super(BindEvaluator, self).__init__()
+
+        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+        # freeze all the esm_model parameters
+        for param in self.esm_model.parameters():
+            param.requires_grad = False
+
+        self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
+                                               n_head, d_k, d_v, d_inner, dropout=dropout)
+
+        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
+                                                                d_k, d_v, dropout=dropout)
+
+        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
+
+        self.output_projection_prot = nn.Linear(d_model, 1)
+
+        self.learning_rate = learning_rate
+        self.max_epochs = max_epochs
+        self.kl_weight = kl_weight
+
+        self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # Initial threshold
+        self.historical_memory = 0.9
+        # [binding-site weight, non-binding-site weight]
+        self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])
+
+    def forward(self, binder_tokens, target_tokens):
+        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
+        protein_sequence = self.esm_model(**target_tokens).last_hidden_state
+
+        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
+            seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
+                                                                                    protein_sequence)
+
+        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
+
+        prot_enc = self.final_ffn(prot_enc)
+
+        prot_enc = self.output_projection_prot(prot_enc)
+
+        return prot_enc
+
+    def get_probs(self, xt, target_sequence):
+        '''
+        Inputs:
+            - xt: Shape (bsz*seq_len*vocab_size, seq_len)
+            - target_sequence: Shape (bsz*seq_len*vocab_size, tgt_len)
+        '''
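+        # Layout note (illustrative): the flattened bsz*seq_len*vocab_size
+        # leading axis suggests one row per per-position token substitution
+        # of the binder, each row scored against the same target sequence.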
+        binder_attention_mask = torch.ones_like(xt)
+        target_attention_mask = torch.ones_like(target_sequence)
+
+        # Do not attend to the special tokens at either end
+        binder_attention_mask[:, 0] = binder_attention_mask[:, -1] = 0
+        target_attention_mask[:, 0] = target_attention_mask[:, -1] = 0
+
+        binder_tokens = {'input_ids': xt, 'attention_mask': binder_attention_mask.to(xt.device)}
+        target_tokens = {'input_ids': target_sequence, 'attention_mask': target_attention_mask.to(target_sequence.device)}
+
+        start = time.time()
+        logits = self.forward(binder_tokens, target_tokens).squeeze(-1)
+        # print(f"Time: {time.time() - start} seconds")
+
+        logits[:, 0] = logits[:, -1] = -100  # float('-inf')
+        # F.softmax yields probabilities (not log-probabilities), so name
+        # the result accordingly
+        probs = F.softmax(logits, dim=-1)
+
+        return probs  # shape (bsz*seq_len*vocab_size, tgt_len)
diff --git a/models/bindevaluator_modules/__init__.py b/models/bindevaluator_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0175e72e9d2cbada6c5e3d8f10c57a34192a5bb0
--- /dev/null
+++ b/models/bindevaluator_modules/__init__.py
@@ -0,0 +1,3 @@
+from .models import *
+from .score_domain import *
+from .dataloaders import *
diff --git a/models/bindevaluator_modules/dataloaders.py b/models/bindevaluator_modules/dataloaders.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6d9deaf0ab33682d781b94af70f54d8d0883c8e
--- /dev/null
+++ b/models/bindevaluator_modules/dataloaders.py
@@ -0,0 +1,426 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Sat Jul 31 21:54:08 2021
+
+@author: Osama
+"""
+
+from torch.utils.data import Dataset
+from Bio.PDB import Polypeptide
+import numpy as np
+import torch
+import pandas as pd
+import os
+# import esm
+import ast
+import pdb
+
+
+class InterpepComplexes(Dataset):
+
+    def __init__(self, mode,
+                 encoded_data_directory="../../datasets/interpep_data/"):
+
+        self.mode = mode
+        self.encoded_data_directory = encoded_data_directory
+
+        self.train_dir = "../../datasets/interpep_data/train_examples.npy"
+        self.test_dir = "../../datasets/interpep_data/test_examples.npy"
+        self.val_dir = "../../datasets/interpep_data/val_examples.npy"
+
+        self.test_list = np.load(self.test_dir)
+        self.train_list = np.load(self.train_dir)
+        self.val_list = np.load(self.val_dir)
+
+        if mode == "train":
+            self.num_data = len(self.train_list)
+        elif mode == "val":
+            self.num_data = len(self.val_list)
+        elif mode == "test":
+            self.num_data = len(self.test_list)
+
+    def __getitem__(self, index):
+
+        if self.mode == "train":
+            item = self.train_list[index]
+        elif self.mode == "val":
+            item = self.val_list[index]
+        elif self.mode == "test":
+            item = self.test_list[index]
+
+        file_dir = self.encoded_data_directory
+
+        with np.load(file_dir + "fragment_data/" + item + ".npz") as data:
+            temp_pep_sequence = data["target_sequence"]
+            temp_binding_sites = data["binding_sites"]
+
+        with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\
+                     item.split("_")[1] + ".npz") as
data: + temp_nodes = data["nodes"] + + + binding = np.zeros(len(temp_nodes)) + if len(temp_binding_sites) != 0: + binding[temp_binding_sites] = 1 + target = torch.LongTensor(binding) + + + + + + + + nodes = temp_nodes[:, 0:20] + + prot_sequence = np.argmax(nodes, axis=-1) + + + + prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence]) + + + + pep_sequence = temp_pep_sequence + + pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1) + + + + + + return pep_sequence, prot_sequence, target + + def __len__(self): + return self.num_data + +class PPI(Dataset): + + def __init__(self, mode, csv_dir_path = "/home/u21307130002/PepNN/pepnn/datasets/ppi/"): + + self.mode = mode + self.train_data = pd.read_csv(os.path.join(csv_dir_path, 'train.csv')) + self.val_data = pd.read_csv(os.path.join(csv_dir_path, 'val.csv')) + # self.test_data = pd.read_csv(os.path.join(csv_dir_path, 'test.csv')) + + if self.mode == 'train': + self.num_data = len(self.train_data) + + def __len__(self): + return self.num_data + + def __getitem__(self, index): + # pdb.set_trace() + if torch.is_tensor(index): + index = index.tolist() + + if self.mode == "train": + item = self.train_data.iloc[index] + elif self.mode == "val": + item = self.val_data.iloc[index] + elif self.mode == "test": + item = self.test_data.iloc[index] + else: + item = None + + # print(item) + + motif1 = ast.literal_eval(item['Chain_1_motifs']) + motif2 = ast.literal_eval(item['Chain_2_motifs']) + + if len(motif1[0]) > len(motif2[0]): + target = motif1 + prot_sequence = item['Sequence1'] + pep_sequence = item['Sequence2'] + else: + target = motif2 + pep_sequence = item['Sequence1'] + prot_sequence = item['Sequence2'] + + target = [int(motif.split('_')[1]) for motif in target] + + if target[-1] >= len(prot_sequence): + pdb.set_trace() + + binding = np.zeros(len(prot_sequence)) + if len(target) != 0: + binding[target] = 1 + target = torch.LongTensor(binding).float() + + # print(f"peptide length: {len(pep_sequence)}") + # print(f"protein length: {len(prot_sequence)}") + # print(f"target length: {len(target)}") + # pdb.set_trace() + + return pep_sequence, prot_sequence, target + + + + +class PepBindComplexes(Dataset): + + def __init__(self, mode, + encoded_data_directory = "../../datasets/pepbind_data/"): + + self.mode = mode + + self.encoded_data_directory = encoded_data_directory + + self.train_dir = "../../datasets/pepbind_data/train_examples.npy" + + self.test_dir = "../../datasets/pepbind_data/test_examples.npy" + + self.val_dir = "../../datasets/pepbind_data/val_examples.npy" + + + self.test_list = np.load(self.test_dir) + + self.train_list = np.load(self.train_dir) + + self.val_list = np.load(self.val_dir) + + + if mode == "train": + self.num_data = len(self.train_list) + elif mode == "val": + self.num_data = len(self.val_list) + elif mode == "test": + self.num_data = len(self.test_list) + + + + def __getitem__(self, index): + + if self.mode == "train": + item = self.train_list[index] + + + elif self.mode == "val": + item = self.val_list[index] + + + elif self.mode == "test": + item = self.test_list[index] + + + + file_dir = self.encoded_data_directory + + + with np.load(file_dir + "fragment_data/" + item + ".npz") as data: + temp_pep_sequence = data["target_sequence"] + temp_binding_sites = data["binding_sites"] + + + with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\ + item.split("_")[1] + ".npz") as data: + temp_nodes = data["nodes"] + + + binding = np.zeros(len(temp_nodes)) + if 
len(temp_binding_sites) != 0: + binding[temp_binding_sites] = 1 + target = torch.LongTensor(binding) + + nodes = temp_nodes[:, 0:20] + + prot_sequence = np.argmax(nodes, axis=-1) + + + prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence]) + + + pep_sequence = temp_pep_sequence + + pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1) + + + return pep_sequence, prot_sequence, target + + + def __len__(self): + return self.num_data + +class PeptideComplexes(Dataset): + + def __init__(self, mode, + encoded_data_directory = "../../datasets/pepnn_data/all_data/"): + + self.mode = mode + + self.encoded_data_directory = encoded_data_directory + + self.train_dir = "../../datasets/pepnn_data/train_examples.npy" + + self.test_dir = "../../datasets/pepnn_test_data/test_examples.npy" + + self.val_dir = "../../datasets/pepnn_data/val_examples.npy" + + + self.example_weights = np.load("../../datasets/pepnn_data/example_weights.npy") + + self.test_list = np.load(self.test_dir) + + self.train_list = np.load(self.train_dir) + + self.val_list = np.load(self.val_dir) + + + + if mode == "train": + self.num_data = len(self.train_list) + elif mode == "val": + self.num_data = len(self.val_list) + elif mode == "test": + self.num_data = len(self.test_list) + + + + def __getitem__(self, index): + + + if self.mode == "train": + item = self.train_list[index] + + weight = self.example_weights[item] + + elif self.mode == "val": + item = self.val_list[index] + + weight = self.example_weights[item] + + elif self.mode == "test": + item = self.test_list[index] + + weight = 1 + + if self.mode != "test": + file_dir = self.encoded_data_directory + else: + file_dir = "../../datasets/pepnn_test_data/all_data/" + + + with np.load(file_dir + "fragment_data/" + item + ".npz") as data: + temp_pep_sequence = data["target_sequence"] + temp_binding_sites = data["binding_sites"] + + + with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\ + item.split("_")[1] + ".npz") as data: + temp_nodes = data["nodes"] + + + binding = np.zeros(len(temp_nodes)) + if len(temp_binding_sites) != 0: + binding[temp_binding_sites] = 1 + target = torch.LongTensor(binding) + + + + + + + + nodes = temp_nodes[:, 0:20] + + prot_sequence = np.argmax(nodes, axis=-1) + + + + prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence]) + + + + pep_sequence = temp_pep_sequence + + pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1) + + + + + + return pep_sequence, prot_sequence, target, weight + + + def __len__(self): + return self.num_data + + +class BitenetComplexes(Dataset): + + def __init__(self, encoded_data_directory = "../bitenet_data/all_data/"): + + + self.encoded_data_directory = encoded_data_directory + + + + + self.train_dir = "../../datasets/bitenet_data/examples.npy" + + + + + self.full_list = np.load(self.train_dir) + + + + + self.num_data = len(self.full_list) + + + + + def __getitem__(self, index): + + item = self.full_list[index] + + file_dir = self.encoded_data_directory + + with np.load(file_dir + "fragment_data/" + item[:-1] + "_" + item[-1] + ".npz") as data: + temp_pep_sequence = data["target_sequence"] + temp_binding_matrix = data["binding_matrix"] + + + with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\ + item.split("_")[1][0] + ".npz") as data: + temp_nodes = data["nodes"] + + + binding_sum = np.sum(temp_binding_matrix, axis=0).T + + target = torch.LongTensor(binding_sum >= 1) + + + + nodes = temp_nodes[:, 0:20] + + prot_sequence = 
np.argmax(nodes, axis=-1) + + + + prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence]) + + + + pep_sequence = temp_pep_sequence + + pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1) + + + + + return pep_sequence, prot_sequence, target + + def __len__(self): + return self.num_data \ No newline at end of file diff --git a/models/bindevaluator_modules/layers.py b/models/bindevaluator_modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..179fa1ce31ecb11506e06430a7a70adc159e9565 --- /dev/null +++ b/models/bindevaluator_modules/layers.py @@ -0,0 +1,142 @@ +from torch import nn +from .modules import * +import pdb + +class ConvLayer(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding, dilation): + super(ConvLayer, self).__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation) + self.relu = nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + +class DilatedCNN(nn.Module): + def __init__(self, d_model, d_hidden): + super(DilatedCNN, self).__init__() + self.first_ = nn.ModuleList() + self.second_ = nn.ModuleList() + self.third_ = nn.ModuleList() + + dilation_tuple = (1, 2, 3) + dim_in_tuple = (d_model, d_hidden, d_hidden) + dim_out_tuple = (d_hidden, d_hidden, d_hidden) + + for i, dilation_rate in enumerate(dilation_tuple): + self.first_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=3, padding=dilation_rate, + dilation=dilation_rate)) + + for i, dilation_rate in enumerate(dilation_tuple): + self.second_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=5, padding=2*dilation_rate, + dilation=dilation_rate)) + + for i, dilation_rate in enumerate(dilation_tuple): + self.third_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=7, padding=3*dilation_rate, + dilation=dilation_rate)) + + def forward(self, protein_seq_enc): + # pdb.set_trace() + protein_seq_enc = protein_seq_enc.transpose(1, 2) # protein_seq_enc's shape: B*L*d_model -> B*d_model*L + + first_embedding = protein_seq_enc + second_embedding = protein_seq_enc + third_embedding = protein_seq_enc + + for i in range(len(self.first_)): + first_embedding = self.first_[i](first_embedding) + + for i in range(len(self.second_)): + second_embedding = self.second_[i](second_embedding) + + for i in range(len(self.third_)): + third_embedding = self.third_[i](third_embedding) + + # pdb.set_trace() + + protein_seq_enc = first_embedding + second_embedding + third_embedding + + return protein_seq_enc.transpose(1, 2) + + +class ReciprocalLayerwithCNN(nn.Module): + + def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v): + super().__init__() + + self.cnn = DilatedCNN(d_model, d_hidden) + + self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, + d_k, d_v) + + self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, + d_k, d_v) + + self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, + d_k, d_v) + + self.ffn_seq = FFN(d_hidden, d_inner) + + self.ffn_protein = FFN(d_hidden, d_inner) + + def forward(self, sequence_enc, protein_seq_enc): + # pdb.set_trace() # protein_seq_enc.shape = B * L * d_model + protein_seq_enc = self.cnn(protein_seq_enc) + prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc) + + seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, 
sequence_enc) + + prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, + seq_enc, + seq_enc, + prot_enc) + prot_enc = self.ffn_protein(prot_enc) + + seq_enc = self.ffn_seq(seq_enc) + + return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention + + +class ReciprocalLayer(nn.Module): + + def __init__(self, d_model, d_inner, n_head, d_k, d_v): + + super().__init__() + + self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, + d_k, d_v) + + self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, + d_k, d_v) + + self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, + d_k, d_v) + + + + self.ffn_seq = FFN(d_model, d_inner) + + self.ffn_protein = FFN(d_model, d_inner) + + def forward(self, sequence_enc, protein_seq_enc): + prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc) + + seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc) + + + prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, + seq_enc, + seq_enc, + prot_enc) + prot_enc = self.ffn_protein(prot_enc) + + seq_enc = self.ffn_seq(seq_enc) + + + + return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention + + + diff --git a/models/bindevaluator_modules/models.py b/models/bindevaluator_modules/models.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0a5f672a6767de6ab6ba0b6f759a56167a9a78 --- /dev/null +++ b/models/bindevaluator_modules/models.py @@ -0,0 +1,284 @@ +import pdb + +import numpy as np +import torch +import torch.nn as nn +from .layers import * +from .modules import * +import pdb +from transformers import EsmModel, EsmTokenizer + +def to_var(x): + if torch.cuda.is_available(): + x = x.cuda() + return x + + +class RepeatedModule3(nn.Module): + def __init__(self, n_layers, d_model, d_hidden, + n_head, d_k, d_v, d_inner, dropout=0.1): + super().__init__() + + self.linear1 = nn.Linear(1280, d_model) + self.linear2 = nn.Linear(1280, d_model) + self.sequence_embedding = nn.Embedding(20, d_model) + self.d_model = d_model + + self.reciprocal_layer_stack = nn.ModuleList([ + ReciprocalLayerwithCNN(d_model, d_inner, d_hidden, n_head, d_k, d_v) + for _ in range(n_layers)]) + + self.dropout = nn.Dropout(dropout) + self.dropout_2 = nn.Dropout(dropout) + + def forward(self, peptide_sequence, protein_sequence): + sequence_attention_list = [] + + prot_attention_list = [] + + prot_seq_attention_list = [] + + seq_prot_attention_list = [] + + sequence_enc = self.dropout(self.linear1(peptide_sequence)) + + prot_enc = self.dropout_2(self.linear2(protein_sequence)) + + for reciprocal_layer in self.reciprocal_layer_stack: + prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \ + reciprocal_layer(sequence_enc, prot_enc) + + sequence_attention_list.append(sequence_attention) + + prot_attention_list.append(prot_attention) + + prot_seq_attention_list.append(prot_seq_attention) + + seq_prot_attention_list.append(seq_prot_attention) + + return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \ + seq_prot_attention_list, seq_prot_attention_list + + +class RepeatedModule2(nn.Module): + def __init__(self, n_layers, d_model, + n_head, d_k, d_v, d_inner, dropout=0.1): + super().__init__() + + self.linear1 = nn.Linear(1280, 
d_model) + self.linear2 = nn.Linear(1280, d_model) + self.sequence_embedding = nn.Embedding(20, d_model) + self.d_model = d_model + + self.reciprocal_layer_stack = nn.ModuleList([ + ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v) + for _ in range(n_layers)]) + + self.dropout = nn.Dropout(dropout) + self.dropout_2 = nn.Dropout(dropout) + + def forward(self, peptide_sequence, protein_sequence): + sequence_attention_list = [] + + prot_attention_list = [] + + prot_seq_attention_list = [] + + seq_prot_attention_list = [] + + sequence_enc = self.dropout(self.linear1(peptide_sequence)) + + prot_enc = self.dropout_2(self.linear2(protein_sequence)) + + for reciprocal_layer in self.reciprocal_layer_stack: + prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \ + reciprocal_layer(sequence_enc, prot_enc) + + sequence_attention_list.append(sequence_attention) + + prot_attention_list.append(prot_attention) + + prot_seq_attention_list.append(prot_seq_attention) + + seq_prot_attention_list.append(seq_prot_attention) + + return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \ + seq_prot_attention_list, seq_prot_attention_list + + +class RepeatedModule(nn.Module): + + def __init__(self, n_layers, d_model, + n_head, d_k, d_v, d_inner, dropout=0.1): + + super().__init__() + + self.linear = nn.Linear(1024, d_model) + self.sequence_embedding = nn.Embedding(20, d_model) + self.d_model = d_model + + self.reciprocal_layer_stack = nn.ModuleList([ + ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v) + for _ in range(n_layers)]) + + self.dropout = nn.Dropout(dropout) + self.dropout_2 = nn.Dropout(dropout) + + + + def _positional_embedding(self, batches, number): + + result = torch.exp(torch.arange(0, self.d_model,2,dtype=torch.float32)*-1*(np.log(10000)/self.d_model)) + + numbers = torch.arange(0, number, dtype=torch.float32) + + numbers = numbers.unsqueeze(0) + + numbers = numbers.unsqueeze(2) + + result = numbers*result + + result = torch.cat((torch.sin(result), torch.cos(result)),2) + + return result + + def forward(self, peptide_sequence, protein_sequence): + + + sequence_attention_list = [] + + prot_attention_list = [] + + prot_seq_attention_list = [] + + seq_prot_attention_list = [] + + sequence_enc = self.sequence_embedding(peptide_sequence) + + sequence_enc += to_var(self._positional_embedding(peptide_sequence.shape[0], + peptide_sequence.shape[1])) + sequence_enc = self.dropout(sequence_enc) + + + + + + prot_enc = self.dropout_2(self.linear(protein_sequence)) + + + + + for reciprocal_layer in self.reciprocal_layer_stack: + + prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention =\ + reciprocal_layer(sequence_enc, prot_enc) + + sequence_attention_list.append(sequence_attention) + + prot_attention_list.append(prot_attention) + + prot_seq_attention_list.append(prot_seq_attention) + + seq_prot_attention_list.append(seq_prot_attention) + + + + return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,\ + seq_prot_attention_list, seq_prot_attention_list + + +class FullModel(nn.Module): + + def __init__(self, n_layers, d_model, n_head, + d_k, d_v, d_inner, return_attention=False, dropout=0.2): + super().__init__() + + self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D") + + # freeze all the esm_model parameters + for param in self.esm_model.parameters(): + param.requires_grad = False + + self.repeated_module = RepeatedModule2(n_layers, d_model, + n_head, 
d_k, d_v, d_inner, dropout=dropout) + + self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model, + d_k, d_v, dropout=dropout) + + self.final_ffn = FFN(d_model, d_inner, dropout=dropout) + + self.output_projection_prot = nn.Linear(d_model, 1) + self.sigmoid = nn.Sigmoid() + + self.return_attention = return_attention + + def forward(self, binder_tokens, target_tokens): + + with torch.no_grad(): + peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state + protein_sequence = self.esm_model(**target_tokens).last_hidden_state + + # pdb.set_trace() + + prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \ + seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence, + protein_sequence) + + prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc) + + # pdb.set_trace() + + prot_enc = self.final_ffn(prot_enc) + + prot_enc = self.sigmoid(self.output_projection_prot(prot_enc)) + + return prot_enc + + + +class Original_FullModel(nn.Module): + + def __init__(self, n_layers, d_model, n_head, + d_k, d_v, d_inner, return_attention=False, dropout=0.2): + + super().__init__() + self.repeated_module = RepeatedModule(n_layers, d_model, + n_head, d_k, d_v, d_inner, dropout=dropout) + + self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model, + d_k, d_v, dropout=dropout) + + self.final_ffn = FFN(d_model, d_inner, dropout=dropout) + self.output_projection_prot = nn.Linear(d_model, 2) + + + + self.softmax_prot =nn.LogSoftmax(dim=-1) + + + self.return_attention = return_attention + + def forward(self, peptide_sequence, protein_sequence): + + prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,\ + seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence, + protein_sequence) + + + + prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc) + + prot_enc = self.final_ffn(prot_enc) + + prot_enc = self.softmax_prot(self.output_projection_prot(prot_enc)) + + + + + + if not self.return_attention: + return prot_enc + else: + return prot_enc, sequence_attention_list, prot_attention_list,\ + seq_prot_attention_list, seq_prot_attention_list + diff --git a/models/bindevaluator_modules/modules.py b/models/bindevaluator_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..63a83d8f24e4aae9b8f52957920a5cd1b2755f1e --- /dev/null +++ b/models/bindevaluator_modules/modules.py @@ -0,0 +1,187 @@ +from torch import nn +import numpy as np +import torch +import torch.nn.functional as F + + +def to_var(x): + if torch.cuda.is_available(): + x = x.cuda() + return x + + + + + +class MultiHeadAttentionSequence(nn.Module): + + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + + super().__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_k = d_k + self.d_v = d_v + + self.W_Q = nn.Linear(d_model, n_head*d_k) + self.W_K = nn.Linear(d_model, n_head*d_k) + self.W_V = nn.Linear(d_model, n_head*d_v) + self.W_O = nn.Linear(n_head*d_v, d_model) + + self.layer_norm = nn.LayerNorm(d_model) + + self.dropout = nn.Dropout(dropout) + + def forward(self, q, k, v): + + batch, len_q, _ = q.size() + batch, len_k, _ = k.size() + batch, len_v, _ = v.size() + + Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k]) + K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k]) + V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v]) + + Q = Q.transpose(1, 2) + K = 
K.transpose(1, 2).transpose(2, 3)
+        V = V.transpose(1, 2)
+
+        # Scaled dot-product attention:
+        # (B, h, Lq, dk) @ (B, h, dk, Lk) -> (B, h, Lq, Lk)
+        attention = torch.matmul(Q, K)
+        attention = attention / np.sqrt(self.d_k)
+        attention = F.softmax(attention, dim=-1)
+
+        output = torch.matmul(attention, V)
+        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
+        output = self.W_O(output)
+        output = self.dropout(output)
+        output = self.layer_norm(output + q)
+
+        return output, attention
+
+
+class MultiHeadAttentionReciprocal(nn.Module):
+
+    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+
+        super().__init__()
+
+        self.n_head = n_head
+        self.d_model = d_model
+        self.d_k = d_k
+        self.d_v = d_v
+
+        self.W_Q = nn.Linear(d_model, n_head*d_k)
+        self.W_K = nn.Linear(d_model, n_head*d_k)
+        self.W_V = nn.Linear(d_model, n_head*d_v)
+        self.W_O = nn.Linear(n_head*d_v, d_model)
+        self.W_V_2 = nn.Linear(d_model, n_head*d_v)
+        self.W_O_2 = nn.Linear(n_head*d_v, d_model)
+
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+        self.layer_norm_2 = nn.LayerNorm(d_model)
+        self.dropout_2 = nn.Dropout(dropout)
+
+    def forward(self, q, k, v, v_2):
+
+        batch, len_q, _ = q.size()
+        batch, len_k, _ = k.size()
+        batch, len_v, _ = v.size()
+        batch, len_v_2, _ = v_2.size()
+
+        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
+        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
+        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
+        V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])
+
+        Q = Q.transpose(1, 2)
+        K = K.transpose(1, 2).transpose(2, 3)
+        V = V.transpose(1, 2)
+        V_2 = V_2.transpose(1, 2)
+
+        # Share one score matrix in both directions: softmax over keys for
+        # q->k attention, and over queries (after transposing) for k->q
+        attention = torch.matmul(Q, K)
+        attention = attention / np.sqrt(self.d_k)
+        attention_2 = attention.transpose(-2, -1)
+
+        attention = F.softmax(attention, dim=-1)
+        attention_2 = F.softmax(attention_2, dim=-1)
+
+        output = torch.matmul(attention, V)
+        output_2 = torch.matmul(attention_2, V_2)
+
+        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
+        output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])
+
+        output = self.W_O(output)
+        output_2 = self.W_O_2(output_2)
+
+        output = self.dropout(output)
+        output = self.layer_norm(output + q)
+
+        # The second stream gets its own dropout/norm pair
+        output_2 = self.dropout_2(output_2)
+        output_2 = self.layer_norm_2(output_2 + k)
+
+        return output, output_2, attention, attention_2
+
+
+class FFN(nn.Module):
+
+    def __init__(self, d_in, d_hid, dropout=0.1):
+        super().__init__()
+
+        self.layer_1 = nn.Conv1d(d_in, d_hid, 1)
+        self.layer_2 = nn.Conv1d(d_hid, d_in, 1)
+        self.relu = nn.ReLU()
+        self.layer_norm = nn.LayerNorm(d_in)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+
+        residual = x
+        output = self.layer_1(x.transpose(1, 2))
+        output = self.relu(output)
+        output = self.layer_2(output)
+        output = self.dropout(output)
+        output = self.layer_norm(output.transpose(1, 2) + residual)
+
+        return output
diff --git a/models/bindevaluator_modules/score_domain.py b/models/bindevaluator_modules/score_domain.py
new file mode 100644
index 0000000000000000000000000000000000000000..17c56ec8b52dbed7cef4cebb1cbb5b92b091627e
--- /dev/null
+++ b/models/bindevaluator_modules/score_domain.py
@@ -0,0 +1,40 @@
+from scipy.stats import norm
+import numpy as np
+import os
+
+
+def score(outputs):
+
+    weight = 0.03
+    binding_size_dist = np.load(os.path.join(os.path.dirname(__file__), "../params/binding_size_train_dist.npy"))
+
+    mean = np.mean(binding_size_dist)
+
+    std =
np.std(binding_size_dist) + + dist = norm(mean, std) + + + max_score = 0 + + + + scores = np.exp(outputs[0])[:, 1] + + indices = np.argsort(-1*scores) + + for j in range(1, len(indices)): + + + + score = (1-weight)*np.mean(scores[indices[:j]]) + weight*(dist.pdf(j/len(indices))) + + + if score > max_score: + + max_score = score + + + return max_score + \ No newline at end of file diff --git a/models/dimamba.py b/models/dimamba.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb33b2941fb27ade9d9afe43ed00becb1474503 --- /dev/null +++ b/models/dimamba.py @@ -0,0 +1,1235 @@ +import math +from functools import partial +from typing import Optional, Tuple, Union + +import huggingface_hub +import omegaconf +import torch +import torch.nn as nn +import torch.nn.functional as F +from causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, +) +from einops import rearrange, repeat +from mamba_ssm.ops.selective_scan_interface import ( + mamba_inner_fn, + selective_scan_fn, +) +from torch import Tensor +from transformers import PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import ( + BaseModelOutputWithNoAttention, + MaskedLMOutput, +) + +try: + # Legacy mamba-ssm v1 file structure + from mamba_ssm.ops.triton.layernorm import ( + RMSNorm, layer_norm_fn, rms_norm_fn + ) +except ImportError: + try: + # mamba-ssm v2 file structure + from mamba_ssm.ops.triton.layer_norm import ( + RMSNorm, layer_norm_fn, rms_norm_fn + ) + except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None +from mamba_ssm.ops.triton.selective_state_update import ( + selective_state_update, +) +from mamba_ssm.utils.generation import InferenceParams + +from models.dit import ( + LabelEmbedder, + TimestepEmbedder, + bias_dropout_add_scale_fused_inference, + bias_dropout_add_scale_fused_train, + modulate_fused, +) + +class Mamba(nn.Module): + def __init__( + self, + d_model, + d_state=16, + d_conv=4, + expand=2, + dt_rank='auto', + dt_min=0.001, + dt_max=0.1, + dt_init='random', + dt_scale=1.0, + dt_init_floor=1e-4, + conv_bias=True, + bias=False, + use_fast_path=True, # Fused kernel options + layer_idx=None, + device=None, + dtype=None, + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == 'auto' else dt_rank + self.use_fast_path = use_fast_path + self.layer_idx = layer_idx + + self.in_proj = nn.Linear( + self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs + ) + + self.conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + groups=self.d_inner, + padding=d_conv - 1, + **factory_kwargs, + ) + + self.activation = 'silu' + self.act = nn.SiLU() + + self.x_proj = nn.Linear( + self.d_inner, self.dt_rank + self.d_state * 2, + bias=False, **factory_kwargs) + self.dt_proj = nn.Linear( + self.dt_rank, self.d_inner, + bias=True, **factory_kwargs) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = self.dt_rank**-0.5 * dt_scale + if dt_init == 'constant': + nn.init.constant_(self.dt_proj.weight, dt_init_std) + elif dt_init == 'random': + nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + 
torch.rand(self.d_inner, **factory_kwargs) + * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + self.dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + self.dt_proj.bias._no_reinit = True + + # S4D real initialization + A = repeat( + torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), + 'n -> d n', + d=self.d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D 'skip' parameter + self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 + self.D._no_weight_decay = True + + self.out_proj = nn.Linear( + self.d_inner, self.d_model, bias=bias, **factory_kwargs + ) + + def forward(self, hidden_states, inference_params=None): + """ + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + batch, seqlen, dim = hidden_states.shape + + conv_state, ssm_state = None, None + if inference_params is not None: + conv_state, ssm_state = self._get_states_from_cache( + inference_params, batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, _, _ = self.step( + hidden_states, conv_state, ssm_state) + return out + + # We do matmul and transpose BLH -> HBL at the same time + xz = rearrange( + self.in_proj.weight @ rearrange(hidden_states, 'b l d -> d (b l)'), + 'd (b l) -> b d l', + l=seqlen, + ) + if self.in_proj.bias is not None: + xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), 'd -> d 1') + + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + # In the backward pass we write dx and dz next to each other to avoid torch.cat + + if ( + self.use_fast_path + and causal_conv1d_fn is not None + and inference_params is None + ): # Doesn't support outputting the states + out = mamba_inner_fn( + xz, + self.conv1d.weight, + self.conv1d.bias, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias, + A, + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + ) + + else: + x, z = xz.chunk(2, dim=1) + # Compute short convolution + if conv_state is not None: + # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + conv_state.copy_( + F.pad(x, (self.d_conv - x.shape[-1], 0)) + ) # Update state (B D W) + if causal_conv1d_fn is None: + x = self.act(self.conv1d(x)[..., :seqlen]) + else: + assert self.activation in ['silu', 'swish'] + x = causal_conv1d_fn( + x=x, + weight=rearrange(self.conv1d.weight, 'd 1 w -> d w'), + bias=self.conv1d.bias, + activation=self.activation, + state=conv_state,) + + # We're careful here about the layout, to avoid extra transposes. + # We want dt to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
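+            # Resulting layouts (B=batch, L=seqlen): dt -> (B, d_inner, L),
+            # B, C -> (B, d_state, L), which is what selective_scan_fn
+            # consumes below.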
+ x_dbl = self.x_proj(rearrange(x, 'b d l -> (b l) d')) # (bl d) + dt, B, C = torch.split( + x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1 + ) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, 'd (b l) -> b d l', l=seqlen) + B = rearrange(B, '(b l) dstate -> b dstate l', l=seqlen).contiguous() + C = rearrange(C, '(b l) dstate -> b dstate l', l=seqlen).contiguous() + + assert self.activation in ['silu', 'swish'] + + y = selective_scan_fn( + x, + dt, + A, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=ssm_state is not None, + ) + + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(last_state) + y = rearrange(y, 'b d l -> b l d') + + out = self.out_proj(y) + return out + + def step(self, hidden_states, conv_state, ssm_state): + dtype = hidden_states.dtype + assert ( + hidden_states.shape[1] == 1 + ), 'Only support decoding with 1 token at a time for now' + xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + x, z = xz.chunk(2, dim=-1) # (B D) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_( + torch.roll(conv_state, shifts=-1, dims=-1) + ) # Update state (B D W) + conv_state[:, :, -1] = x + x = torch.sum( + conv_state * rearrange(self.conv1d.weight, 'd 1 w -> d w'), dim=-1 + ) # (B D) + if self.conv1d.bias is not None: + x = x + self.conv1d.bias + x = self.act(x).to(dtype=dtype) + else: + x = causal_conv1d_update( + x.to(dtype), + conv_state.to(dtype), + rearrange(self.conv1d.weight, 'd 1 w -> d w'), + self.conv1d.bias, + self.activation, + ) + + x_db = self.x_proj(x) # (B dt_rank+2*d_state) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + # Don't add dt_bias here + dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + + # SSM step + if selective_state_update is None: + # Discretize A and B + dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) + dA = torch.exp(torch.einsum('bd,dn->bdn', dt, A)) + dB = torch.einsum('bd,bn->bdn', dt, B) + ssm_state.copy_(ssm_state * dA + rearrange(x, 'b d -> b d 1') * dB) + y = torch.einsum('bdn,bn->bd', ssm_state.to(dtype), C) + y = y + self.D.to(dtype) * x + y = y * self.act(z) # (B D) + else: + y = selective_state_update( + ssm_state, + x, + dt, + A, + B, + C, + self.D, + z=z, + dt_bias=self.dt_proj.bias, + dt_softplus=True, + ) + + out = self.out_proj(y) + return out.unsqueeze(1), conv_state, ssm_state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_conv, + device=device, + dtype=conv_dtype, + ) + ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype + # ssm_dtype = torch.float32 + ssm_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return conv_state, ssm_state + + def _get_states_from_cache( + self, inference_params, batch_size, initialize_states=False + ): + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + batch_shape = (batch_size,) + conv_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.d_model * self.expand, + 
self.d_state, +                device=self.dt_proj.weight.device, +                dtype=self.dt_proj.weight.dtype, +                # dtype=torch.float32, +            ) +            inference_params.key_value_memory_dict[self.layer_idx] = ( +                conv_state, +                ssm_state, +            ) +        else: +            conv_state, ssm_state = inference_params.key_value_memory_dict[ +                self.layer_idx +            ] +            # TODO: What if batch size changes between generation, and we reuse the same states? +            if initialize_states: +                conv_state.zero_() +                ssm_state.zero_() +        return conv_state, ssm_state + + +class Block(nn.Module): +    def __init__( +        self, +        dim, +        mixer_cls, +        norm_cls=nn.LayerNorm, +        fused_add_norm=False, +        residual_in_fp32=False, +        use_adaLN=False, +        cond_dim=0, +    ): +        """ +        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection. + +        This Block has a slightly different structure compared to a regular +        prenorm Transformer block. +        The standard block is: LN -> MHA/MLP -> Add. +        [Ref: https://arxiv.org/abs/2002.04745] +        Here we have: Add -> LN -> Mixer, returning both +        the hidden_states (output of the mixer) and the residual. +        This is purely for performance reasons, as we can fuse add and LayerNorm. +        The residual needs to be provided (except for the very first block). +        """ +        super().__init__() +        self.residual_in_fp32 = residual_in_fp32 +        self.fused_add_norm = fused_add_norm +        self.mixer = mixer_cls(dim) +        self.norm = norm_cls(dim) + +        if self.fused_add_norm: +            assert RMSNorm is not None, 'RMSNorm import fails' +            assert isinstance( +                self.norm, (nn.LayerNorm, RMSNorm) +            ), 'Only LayerNorm and RMSNorm are supported for fused_add_norm' + +        self.dropout = 0.1 + +        self.use_adaLN = use_adaLN +        self.cond_dim = cond_dim +        if use_adaLN: +            self.adaLN_modulation = nn.Linear( +                cond_dim, 3 * dim, bias=True) +            self.adaLN_modulation.weight.data.zero_() +            self.adaLN_modulation.bias.data.zero_() + +    def _get_bias_dropout_scale(self): +        return bias_dropout_add_scale_fused_train if self.training else bias_dropout_add_scale_fused_inference + +    def forward( +        self, +        hidden_states: Tensor, +        residual: Optional[Tensor] = None, +        cond_embeds: Optional[Tensor] = None, +        inference_params: Optional[InferenceParams] = None, +    ): +        r"""Pass the input through the encoder layer. + +        Args: +            hidden_states: the sequence to the encoder layer (required). +            residual: hidden_states = Mixer(LN(residual)) +            cond_embeds: conditional embeddings for modulation (optional). +            inference_params: parameters for inference (optional).
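+ +        Returns: +            (hidden_states, residual): the block output and the running +            residual stream, to be passed to the next block.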
+ """ + if not self.fused_add_norm: + residual = ( + (hidden_states + residual) + if residual is not None + else hidden_states + ) + + hidden_states = self.norm( + residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = ( + rms_norm_fn + if isinstance(self.norm, RMSNorm) + else layer_norm_fn + ) + + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps) + + if self.use_adaLN and cond_embeds is not None: + (shift_msa, + scale_msa, + gate_msa) = self.adaLN_modulation( + cond_embeds)[:, None].chunk(3, dim=-1) + hidden_states = modulate_fused(hidden_states, + shift_msa, + scale_msa) + else: + gate_msa = None + + mixer_out = self.mixer(hidden_states, inference_params=inference_params) + + hidden_states = mixer_out + if self.use_adaLN and cond_embeds is not None: + bias_dropout_scale_fn = self._get_bias_dropout_scale() + hidden_states = bias_dropout_scale_fn( + hidden_states, + None, + gate_msa, + residual, + self.dropout) + + return hidden_states, residual + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache( + batch_size, max_seqlen, dtype=dtype, **kwargs) + + +class BiMambaConfig(PretrainedConfig): + """Config that extends the original MambaConfig with params relevant to bi-directionality.""" + + model_type = 'bimamba' + + def __init__( + self, + # From original MambaConfig + d_model: int = 2560, + n_layer: int = 64, + vocab_size: int = 50277, + ssm_cfg: Optional[dict] = None, + rms_norm: bool = True, + residual_in_fp32: bool = True, + fused_add_norm: bool = True, + pad_vocab_size_multiple: int = 8, + tie_word_embeddings: bool = True, + # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm + norm_epsilon: float = 1e-5, + # Used in init_weights + initializer_cfg: Optional[dict] = None, + # Caduceus-specific params + bidirectional: bool = True, + bidirectional_strategy: Union[str, None] = 'add', + bidirectional_weight_tie: bool = True, + use_adaLN: bool = True, + cond_dim: int = 128, + **kwargs, + ): + super().__init__(**kwargs) + self.d_model = d_model + self.n_layer = n_layer + self.vocab_size = vocab_size + self.ssm_cfg = ssm_cfg + self.rms_norm = rms_norm + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.pad_vocab_size_multiple = pad_vocab_size_multiple + self.tie_word_embeddings = tie_word_embeddings + self.norm_epsilon = norm_epsilon + self.initializer_cfg = initializer_cfg + self.bidirectional = bidirectional + self.bidirectional_strategy = bidirectional_strategy + self.bidirectional_weight_tie = bidirectional_weight_tie + self.use_adaLN = use_adaLN + self.cond_dim = cond_dim + +def create_block( + d_model, + ssm_cfg=None, + norm_epsilon=1e-5, + rms_norm=False, + residual_in_fp32=False, + fused_add_norm=False, + layer_idx=None, + bidirectional=True, + bidirectional_strategy='add', + bidirectional_weight_tie=True, + device=None, + dtype=None, + use_adaLN=False, + cond_dim=0, +): + """Create BiMamba block. 
+ + Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py + """ + if ssm_cfg is None: + ssm_cfg = {} + factory_kwargs = {'device': device, 'dtype': dtype} + bidirectional_kwargs = { + 'bidirectional': bidirectional, + 'bidirectional_strategy': bidirectional_strategy, + 'bidirectional_weight_tie': bidirectional_weight_tie, + } + mixer_cls = partial( + BiMambaWrapper, + layer_idx=layer_idx, + **ssm_cfg, + **bidirectional_kwargs, + **factory_kwargs, + ) + norm_cls = partial( + nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs + ) + block_cls = Block + block = block_cls( + d_model, + mixer_cls, + norm_cls=norm_cls, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + use_adaLN=use_adaLN, + cond_dim=cond_dim, + ) + block.layer_idx = layer_idx + + return block + + +class BiMambaWrapper(nn.Module): + """Thin wrapper around Mamba to support bi-directionality.""" + + def __init__( + self, + d_model: int, + bidirectional: bool = True, + bidirectional_strategy: Optional[str] = 'add', + bidirectional_weight_tie: bool = True, + **mamba_kwargs, + ): + super().__init__() + if bidirectional and bidirectional_strategy is None: + bidirectional_strategy = 'add' # Default strategy: `add` + if bidirectional and bidirectional_strategy not in ['add', 'ew_multiply']: + raise NotImplementedError( + f'`{bidirectional_strategy}` strategy for bi-directionality is not implemented!' + ) + self.bidirectional = bidirectional + self.bidirectional_strategy = bidirectional_strategy + + self.mamba_fwd = Mamba(d_model=d_model, **mamba_kwargs) + + self.mamba_rev = None + if bidirectional: + self.mamba_rev = Mamba(d_model=d_model, **mamba_kwargs) + if bidirectional_weight_tie: # Tie in and out projections (where most of param count lies) + self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight + self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias + self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight + self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias + else: + self.mamba_rev = None + + def forward(self, hidden_states, inference_params=None): + """Bidirectional-enabled forward pass + + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + + out = self.mamba_fwd( + hidden_states, inference_params=inference_params,) + + if self.bidirectional: + if inference_params is not None: + raise NotImplementedError( + 'Passing `inference_params` not supported ' + 'for bidirectional Mamba.') + + hidden_states_flipped = torch.flip(hidden_states, dims=(1,)) + + out_rev = self.mamba_rev( + hidden_states_flipped, # Flip along the sequence length dimension + inference_params=inference_params,) + + out_rev_flipped = torch.flip(out_rev, dims=(1,)) + if self.bidirectional_strategy == 'add': + out = out + out_rev_flipped # Flip back for combining with forward hidden states + elif self.bidirectional_strategy == 'ew_multiply': + out = out * out_rev_flipped + else: + raise NotImplementedError( + f"`{self.bidirectional_strategy}` for " + f"bi-directionality not implemented!") + return out + + def allocate_inference_cache( + self, batch_size, max_seqlen, dtype=None, **kwargs): + if self.bidirectional: + raise NotImplementedError( + 'Allocating inference cache not supported ' + 'for bidirectional Mamba.') + return self.mamba_fwd.allocate_inference_cache( + batch_size, max_seqlen, dtype=dtype, **kwargs) + + +class BiMambaEmbeddings(nn.Module): + def __init__( + self, + config: BiMambaConfig, + input_dim=None, + 
device=None, +        dtype=None, +    ): +        super().__init__() +        factory_kwargs = {'device': device, 'dtype': dtype} +        if input_dim is None: +            input_dim = config.vocab_size +        self.word_embeddings = nn.Embedding( +            input_dim, config.d_model, **factory_kwargs +        ) + +    def forward(self, input_ids): +        """ +        input_ids: (batch, seqlen) +        """ +        return self.word_embeddings(input_ids) + + +class BiMambaMixerModel(nn.Module): +    def __init__( +            self, +            config: BiMambaConfig, +            device=None, +            dtype=None, +    ) -> None: +        super().__init__() +        factory_kwargs = {'device': device, 'dtype': dtype} +        self.config = config +        input_dim = config.vocab_size +        d_model = config.d_model + +        self.fused_add_norm = config.fused_add_norm +        self.residual_in_fp32 = config.residual_in_fp32 + +        self.embeddings = BiMambaEmbeddings( +            config, input_dim=input_dim, **factory_kwargs) + +        # Mamba changes the order of residual and layer norm: +        # Instead of LN -> Attn / MLP -> Add, we do: +        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and +        # the main branch (output of MLP / Mixer). The model definition is unchanged. +        # This is for performance reasons: we can fuse add + layer_norm. +        if config.fused_add_norm: +            if layer_norm_fn is None or rms_norm_fn is None: +                raise ImportError('Failed to import Triton LayerNorm / RMSNorm kernels') + +        self.layers = nn.ModuleList( +            [ +                create_block( +                    d_model, +                    ssm_cfg=config.ssm_cfg, +                    norm_epsilon=config.norm_epsilon, +                    rms_norm=config.rms_norm, +                    residual_in_fp32=config.residual_in_fp32, +                    fused_add_norm=config.fused_add_norm, +                    layer_idx=i, +                    bidirectional=config.bidirectional, +                    bidirectional_strategy=config.bidirectional_strategy, +                    bidirectional_weight_tie=config.bidirectional_weight_tie, +                    use_adaLN=config.use_adaLN, +                    cond_dim=config.cond_dim, +                    **factory_kwargs, +                ) +                for i in range(config.n_layer) +            ] +        ) + +        if config.use_adaLN: +            self.adaLN_modulation_final = nn.Linear( +                config.cond_dim, 2 * d_model, bias=True) +            self.adaLN_modulation_final.weight.data.zero_() +            self.adaLN_modulation_final.bias.data.zero_() +        else: +            self.adaLN_modulation_final = None + +        norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)( +            d_model, eps=config.norm_epsilon, **factory_kwargs) +        self.norm_f = norm_f + +    def forward( +            self, +            input_ids: torch.LongTensor, +            inputs_embeds: Optional[torch.FloatTensor] = None, +            hidden_states: Optional[torch.FloatTensor] = None, +            output_hidden_states: Optional[bool] = None, +            cond_embeds: Optional[torch.Tensor] = None, +            inference_params: Optional[InferenceParams] = None +    ): + +        """Mixer forward.""" +        all_hidden_states = [] +        if hidden_states is None: +            if inputs_embeds is not None: +                hidden_states = inputs_embeds +            else: +                if input_ids.ndim == 2:  # indices (B, L) +                    hidden_states = self.embeddings(input_ids) +                else:  # one-hots (B, L, V) +                    hidden_states = F.linear( +                        input_ids.to(torch.float), +                        self.embeddings.word_embeddings.weight.T) + +        residual = None +        for ind, layer in enumerate(self.layers): +            if output_hidden_states: +                all_hidden_states.append(hidden_states) +            # TODO: Add support for gradient checkpointing +            layer_out = layer( +                hidden_states, residual, +                inference_params=inference_params, +                cond_embeds=cond_embeds +            ) + +            hidden_states, residual = layer_out + +        if not self.fused_add_norm: +            if self.config.use_adaLN: +                raise NotImplementedError('adaln only implemented for fused_add_norm') +            residual = ( +                (hidden_states + residual) if residual is not None else hidden_states +            ) +            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) +        else: +            if cond_embeds is not None and self.config.use_adaLN: +                shift, scale = self.adaLN_modulation_final( +                    cond_embeds)[:, None].chunk(2, dim=2) +            else: +                shift, scale = None, None + +            fused_add_norm_fn = ( +                rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn +            ) + +            # Set prenorm=False here since we don't need the residual +            hidden_states = fused_add_norm_fn( +                hidden_states, +                self.norm_f.weight, +                self.norm_f.bias, +                eps=self.norm_f.eps, +                residual=residual, +                prenorm=False, +                residual_in_fp32=self.residual_in_fp32, +            ) +            if cond_embeds is not None and self.config.use_adaLN: +                hidden_states = modulate_fused(hidden_states, shift, scale) + +        if output_hidden_states: +            all_hidden_states.append(hidden_states) + +        return hidden_states, all_hidden_states + +    def allocate_inference_cache( +            self, batch_size, max_seqlen, dtype=None, **kwargs): +        return { +            i: layer.allocate_inference_cache( +                batch_size, max_seqlen, dtype=dtype, **kwargs) +            for i, layer in enumerate(self.layers) +        } + + +def cross_entropy(logits, y, ignore_index=-100): +    """Cross-entropy loss.""" +    logits = logits.view(-1, logits.shape[-1]) +    y = y.view(-1) +    return F.cross_entropy(logits, y, ignore_index=ignore_index) + + +def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100): +    """Weighted cross-entropy loss (discounts certain tokens).""" +    logits = logits.view(-1, logits.shape[-1]) +    y = y.view(-1) +    ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction='none') +    loss_weights = loss_weights.view(-1) +    loss_weights[y == ignore_index] = 0.0 +    return (ce * (loss_weights / loss_weights.sum())).sum() + + +class BiMambaPreTrainedModel(PreTrainedModel): +    """PreTrainedModel wrapper for BiMamba backbone.""" + +    config_class = BiMambaConfig +    base_model_prefix = 'bimamba' +    supports_gradient_checkpointing = False +    _no_split_modules = ['BiMambaWrapper'] + +    def _init_weights( +            self, +            module, +            initializer_range=0.02,  # Now only used for embedding layer. +            **kwargs, +    ): +        """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py""" + +        n_layer = self.config.n_layer +        initialized_cfg = self.config.initializer_cfg if self.config.initializer_cfg is not None else {} +        rescale_prenorm_residual = initialized_cfg.get('rescale_prenorm_residual', True) +        initializer_range = initialized_cfg.get('initializer_range', initializer_range) +        n_residuals_per_layer = initialized_cfg.get('n_residuals_per_layer', 1) + +        if isinstance(module, nn.Linear): +            if module.bias is not None: +                if not getattr(module.bias, '_no_reinit', False): +                    nn.init.zeros_(module.bias) +        elif isinstance(module, nn.Embedding): +            nn.init.normal_(module.weight, std=initializer_range) + +        if rescale_prenorm_residual: +            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: +            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. +            #   > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of +            #   > residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ['out_proj.weight', 'fc2.weight']: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) + + +class BiMamba(BiMambaPreTrainedModel): + """BiMamba model that can be instantiated using HF patterns.""" + + def __init__(self, config: BiMambaConfig, device=None, dtype=None, **kwargs): + super().__init__(config) + + # Adjust vocab size if vocab padding is set. + if config.vocab_size % config.pad_vocab_size_multiple != 0: + config.vocab_size += config.pad_vocab_size_multiple - ( + config.vocab_size % config.pad_vocab_size_multiple + ) + + self.config = config + factory_kwargs = {'device': device, 'dtype': dtype} + self.backbone = BiMambaMixerModel(config, **factory_kwargs, **kwargs) + + def forward( + self, + input_ids: torch.LongTensor = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + hidden_states: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cond_embeds: Optional[torch.Tensor] = None, + inference_params: Optional[InferenceParams] = None, + ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]: + """HF-compatible forward method.""" + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + backbone_out = self.backbone( + input_ids, + inputs_embeds=inputs_embeds, + hidden_states=hidden_states, + output_hidden_states=output_hidden_states, + cond_embeds=cond_embeds, + inference_params=inference_params, + ) + + hidden_states, all_hidden_states = backbone_out + + if return_dict: + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states if output_hidden_states else None, + ) + elif output_hidden_states: + return hidden_states, all_hidden_states + else: + return hidden_states + + +class BiMambaForMaskedLM(BiMambaPreTrainedModel): + """HF-compatible BiMamba model for masked language modeling.""" + + def __init__(self, config: BiMambaConfig, device=None, dtype=None, **kwargs): + super().__init__(config, **kwargs) + factory_kwargs = {'device': device, 'dtype': dtype} + self.bimamba = BiMamba(config, **factory_kwargs, **kwargs) + self.config = config + lm_head_in_dim = config.d_model + # LM head may only take in concatenated timestep embeddings + # if its weights are not tied to the vocab embedding + self.lm_head = nn.Linear( + lm_head_in_dim, + self.config.vocab_size, # Use BiMamba config as it might have been updated + bias=False, + **factory_kwargs, + ) + # Initialize weights and apply final processing + self.post_init() + if self.config.tie_word_embeddings: + self.tie_weights() + + def init_weights(self): + """ + If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any + initialization logic in `_init_weights`. 
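+        Here, `_init_weights` above follows the Mamba repo's scheme: zero Linear +        biases (except those tagged `_no_reinit`), normal-initialized embeddings, +        and GPT-2-style rescaling of residual projection weights.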
+ """ + + # Initialize weights + self.apply(self._initialize_weights) + + # Tie weights should be skipped when not initializing all weights + # since from_pretrained(...) calls tie weights anyway. + + def post_init(self): + """ + A method executed at the end of each Transformer model initialization, to execute code that needs the model's + modules properly initialized (such as weight initialization). + """ + self.init_weights() + self._backward_compatibility_gradient_checkpointing() + + def get_input_embeddings(self): + return self.bimamba.backbone.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.bimamba.backbone.embeddings.word_embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + """Overrides output embeddings.""" + self.lm_head = new_embeddings + + def tie_weights(self): + """Tie weights.""" + super().tie_weights() + + def get_encoder(self): + """Get encoder (backbone) for the model.""" + return self.bimamba + + def set_encoder(self, encoder): + """Set encoder (backbone) for the model.""" + self.bimamba = encoder + + def forward( + self, + input_ids: torch.LongTensor = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + loss_weights: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cond_embeds: Optional[torch.FloatTensor] = None, + inference_params: Optional[InferenceParams] = None, + num_last_tokens: int = 0 + ) -> Union[Tuple, MaskedLMOutput]: + """HF-compatible forward method.""" + + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.bimamba( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + hidden_states=hidden_states, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cond_embeds=cond_embeds, + inference_params=inference_params, + ) + hidden_states = outputs[0] + + if num_last_tokens > 0: + hidden_states = hidden_states[:, -num_last_tokens:] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if loss_weights is not None: + loss = weighted_cross_entropy( + logits, labels, loss_weights, ignore_index=self.config.pad_token_id + ) + else: + loss = cross_entropy( + logits, labels, ignore_index=self.config.pad_token_id + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states,) + + +class DiMamba(nn.Module, huggingface_hub.PyTorchModelHubMixin): + def __init__(self, config, vocab_size: int, pad_token_id: int): + super().__init__() + if type(config) == dict: + config = omegaconf.OmegaConf.create(config) + + if config.parameterization == 'ar': + self.sigma_map = None + else: + self.sigma_map = TimestepEmbedder(config.model.cond_dim) + if (config.training.guidance is not None or # Training for / using CFG + (hasattr(config, 'guidance') + and config.guidance is not None + and config.guidance.method == 'cfg')): + self.cond_map = LabelEmbedder( + config.data.num_classes + 1, # +1 for mask + config.model.cond_dim) + else: + self.cond_map = None + + 
mamba_config = BiMambaConfig( +        d_model=config.model.hidden_size, +        n_layer=config.model.n_blocks, +        pad_token_id=pad_token_id, +        vocab_size=vocab_size, +        pad_vocab_size_multiple=1, +        tie_word_embeddings=config.model.tie_word_embeddings, +        bidirectional=getattr(config.model, 'bidirectional', True), +        bidirectional_strategy=getattr(config.model, 'bidirectional_strategy', 'add'), +        bidirectional_weight_tie=getattr(config.model, 'bidirectional_weight_tie', True), +        use_adaLN=self.sigma_map is not None or self.cond_map is not None, +        cond_dim=config.model.cond_dim, +    ) + +    self.model = BiMambaForMaskedLM(config=mamba_config) + +  def _get_bias_dropout_scale(self): +    if self.training: +      return bias_dropout_add_scale_fused_train +    else: +      return bias_dropout_add_scale_fused_inference + +  def forward( +      self, +      indices, +      sigma, +      cond=None, +      x_emb=None, +      return_hidden_states=False, +      inference_params=None +  ): +    c = None +    if self.sigma_map is not None: +      c = F.silu(self.sigma_map(sigma)) +    if cond is not None: +      if self.cond_map is None: +        raise ValueError("Conditioning variable provided, " +                         "but model was not initialized " +                         "with condition embedding layer.") +      else: +        c = c + F.silu(self.cond_map(cond)) if c is not None \ +          else F.silu(self.cond_map(cond)) + +    with torch.cuda.amp.autocast(dtype=torch.bfloat16): +      model_out = self.model( +        indices, +        hidden_states=x_emb, +        cond_embeds=c, +        output_hidden_states=return_hidden_states, +        inference_params=inference_params +      ) + +    if return_hidden_states: +      return model_out.logits, model_out.hidden_states +    return model_out.logits + + +class DiMambaClassifier(nn.Module): +  def __init__(self, config, vocab_size: int, pad_token_id: int): +    super().__init__() +    if type(config) == dict: +      config = omegaconf.OmegaConf.create(config) + +    if config.parameterization == 'ar': +      self.sigma_map = None +    else: +      self.sigma_map = TimestepEmbedder(config.classifier_model.cond_dim) + +    mamba_config = BiMambaConfig( +        d_model=config.classifier_model.hidden_size, +        n_layer=config.classifier_model.n_blocks, +        pad_token_id=pad_token_id, +        vocab_size=vocab_size, +        pad_vocab_size_multiple=1, +        tie_word_embeddings=config.classifier_model.tie_word_embeddings, +        bidirectional=getattr(config.classifier_model, 'bidirectional', True), +        bidirectional_strategy=getattr(config.classifier_model, 'bidirectional_strategy', 'add'), +        bidirectional_weight_tie=getattr(config.classifier_model, 'bidirectional_weight_tie', True), +        use_adaLN=self.sigma_map is not None, +        cond_dim=config.classifier_model.cond_dim, +    ) + +    self.model = BiMamba(config=mamba_config) +    self.pooling = getattr(config.classifier_model, 'pooling', 'mean') +    self.output_layer = nn.Linear( +      config.classifier_model.hidden_size, +      config.classifier_model.num_classes) + +  def _get_bias_dropout_scale(self): +    if self.training: +      return bias_dropout_add_scale_fused_train +    else: +      return bias_dropout_add_scale_fused_inference + +  def forward( +      self, +      indices_or_one_hots, +      sigma, +      x_emb=None, +      attention_mask=None +  ): +    c = None +    if self.sigma_map is not None: +      c = F.silu(self.sigma_map(sigma)) + +    with torch.cuda.amp.autocast(dtype=torch.bfloat16): +      x = self.model( +        indices_or_one_hots, +        hidden_states=x_emb, +        cond_embeds=c, +        output_hidden_states=False, +        inference_params=None +      )[0] + +    with torch.cuda.amp.autocast(dtype=torch.bfloat16): +      if self.pooling == 'mean': +        x = x.mean(dim=1) +      elif self.pooling == 'max': +        x = x.max(dim=1).values +      elif self.pooling == 'cls': +        x = x[:, 0] +      elif self.pooling == 'last': +        x = x[:, -1] +      elif self.pooling == 'no_pooling':  # used for ar_fudge +        pass +      elif self.pooling == 'attention_mean':  # used for ar_pplm +        masked_x = x * attention_mask.unsqueeze(2) +        x = torch.sum(masked_x, dim=1) / ( +            torch.sum(attention_mask, dim=1, +                      keepdim=True) + 1e-15) +      else: +        raise NotImplementedError( +          f"`{self.pooling}` method not implemented.") +      x = self.output_layer(x) +    return x + +  def load_pretrained_encoder(self, encoder: nn.Module): +    self.sigma_map = encoder.sigma_map +    self.model = encoder.model.bimamba diff --git a/models/dit.py b/models/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..5f89a812d5046f0659e7898b7aedd57db460da90 --- /dev/null +++ b/models/dit.py @@ -0,0 +1,556 @@ +import math +import typing + +import flash_attn +import flash_attn.layers.rotary +import huggingface_hub +import omegaconf +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +# Flags required to enable jit fusion kernels +torch._C._jit_set_profiling_mode(False) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_override_can_fuse_on_cpu(True) +torch._C._jit_override_can_fuse_on_gpu(True) + + +def bias_dropout_add_scale( +    x: torch.Tensor, +    bias: typing.Optional[torch.Tensor], +    scale: typing.Optional[torch.Tensor], +    residual: typing.Optional[torch.Tensor], +    prob: float, +    training: bool) -> torch.Tensor: +  if bias is None: +    out = F.dropout(x, p=prob, training=training) +  else: +    out = F.dropout(x + bias, p=prob, training=training) +  if scale is not None: +    out = scale * out +  if residual is not None: +    out = residual + out +  return out + + +def get_bias_dropout_add_scale(training): +  def _bias_dropout_add(x, bias, scale, residual, prob): +    return bias_dropout_add_scale( +      x, bias, scale, residual, prob, training) + +  return _bias_dropout_add + + +# function overload +def modulate(x: torch.Tensor, +             shift: torch.Tensor, +             scale: torch.Tensor) -> torch.Tensor: +  return x * (1 + scale) + shift + + +@torch.jit.script +def bias_dropout_add_scale_fused_train( +    x: torch.Tensor, +    bias: typing.Optional[torch.Tensor], +    scale: typing.Optional[torch.Tensor], +    residual: typing.Optional[torch.Tensor], +    prob: float) -> torch.Tensor: +  return bias_dropout_add_scale( +    x, bias, scale, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_scale_fused_inference( +    x: torch.Tensor, +    bias: typing.Optional[torch.Tensor], +    scale: typing.Optional[torch.Tensor], +    residual: typing.Optional[torch.Tensor], +    prob: float) -> torch.Tensor: +  return bias_dropout_add_scale( +    x, bias, scale, residual, prob, False) + + +@torch.jit.script +def modulate_fused(x: torch.Tensor, +                   shift: torch.Tensor, +                   scale: torch.Tensor) -> torch.Tensor: +  return modulate(x, shift, scale) + + +class Rotary(torch.nn.Module): +  def __init__(self, dim, base=10_000): +    super().__init__() +    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) +    self.register_buffer('inv_freq', inv_freq) +    self.seq_len_cached = None +    self.cos_cached = None +    self.sin_cached = None + +  def forward(self, x, seq_dim=1): +    seq_len = x.shape[seq_dim] +    if seq_len != self.seq_len_cached: +      self.seq_len_cached = seq_len +      t = torch.arange( +        x.shape[seq_dim], +        device=x.device).type_as(self.inv_freq) +      freqs = torch.einsum("i,j->ij", t, +                           self.inv_freq.clone()) +      emb = torch.cat((freqs, freqs), dim=-1).to(x.device) +      # dims are: batch, seq_len, qkv, head, dim +      self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1) +
self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1) +      # This makes the transformation on v an identity. +      self.cos_cached[:,:,2,:,:].fill_(1.) +      self.sin_cached[:,:,2,:,:].fill_(0.) + +    return self.cos_cached, self.sin_cached + + +def rotate_half(x): +  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] +  return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(qkv, cos, sin): +  cos = cos[0,:,0,0,:cos.shape[-1]//2] +  sin = sin[0,:,0,0,:sin.shape[-1]//2] +  return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, +                                                        cos, +                                                        sin) + +# function overload +def modulate(x, shift, scale): +  return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +############################################################ +#                        Layers                            # +############################################################ +class LayerNorm(nn.Module): +  def __init__(self, dim): +    super().__init__() +    self.weight = nn.Parameter(torch.ones([dim])) +    self.dim = dim +  def forward(self, x): +    with torch.cuda.amp.autocast(enabled=False): +      x = F.layer_norm(x.float(), [self.dim]) +    return x * self.weight[None,None,:] + + +def residual_linear(x, W, x_skip, residual_scale): +  """x_skip + residual_scale * W @ x""" +  dim_out, dim_in = W.shape[0], W.shape[1] +  return torch.addmm( +    x_skip.view(-1, dim_out), +    x.view(-1, dim_in), +    W.T, +    alpha=residual_scale).view(*x.shape[:-1], dim_out) + + +############################################################ +#      Embedding Layers for Timesteps and Class Labels     # +############################################################ +class TimestepEmbedder(nn.Module): +  """ +  Embeds scalar timesteps into vector representations. +  """ +  def __init__(self, hidden_size, +               frequency_embedding_size=256): +    super().__init__() +    self.mlp = nn.Sequential( +      nn.Linear(frequency_embedding_size, hidden_size, +                bias=True), +      nn.SiLU(), +      nn.Linear(hidden_size, hidden_size, bias=True)) +    self.frequency_embedding_size = frequency_embedding_size + +  @staticmethod +  def timestep_embedding(t, dim, max_period=10000): +    """ +    Create sinusoidal timestep embeddings. +    :param t: a 1-D Tensor of N indices, one per batch element. +                      These may be fractional. +    :param dim: the dimension of the output. +    :param max_period: controls the minimum frequency of the +      embeddings. +    :return: an (N, D) Tensor of positional embeddings.
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + - math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat( + [torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, + torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding( + t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """Embeds class labels into vector representations.""" + def __init__(self, num_classes, cond_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, + cond_size) + self.num_classes = num_classes + + def forward(self, labels): + embeddings = self.embedding_table(labels) + # embeddings = self.mlp(embeddings) + return embeddings + + +############################################################ +# Core Model # +############################################################ + + +class DDiTBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + cond_dim, + mlp_ratio=4, + dropout=0.1, + causal=False, + use_adaLN=True, + ): + super().__init__() + self.n_heads = n_heads + self.causal = causal + + self.norm1 = LayerNorm(dim) + self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False) + self.attn_out = nn.Linear(dim, dim, bias=False) + self.dropout1 = nn.Dropout(dropout) + + self.norm2 = LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, mlp_ratio * dim, bias=True), + nn.GELU(approximate='tanh'), + nn.Linear(mlp_ratio * dim, dim, bias=True)) + self.dropout2 = nn.Dropout(dropout) + self.dropout = dropout + + self.use_adaLN = use_adaLN + if use_adaLN: + self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, + bias=True) + self.adaLN_modulation.weight.data.zero_() + self.adaLN_modulation.bias.data.zero_() + + def _get_bias_dropout_scale(self): + if self.training: + return bias_dropout_add_scale_fused_train + else: + return bias_dropout_add_scale_fused_inference + + + def forward(self, x, rotary_cos_sin, c, seqlens=None): + batch_size, seq_len = x.shape[0], x.shape[1] + + bias_dropout_scale_fn = self._get_bias_dropout_scale() + + if self.use_adaLN: + (shift_msa, scale_msa, gate_msa, shift_mlp, + scale_mlp, gate_mlp) = self.adaLN_modulation( + c)[:, None].chunk(6, dim=2) + else: + (shift_msa, scale_msa, gate_msa, + shift_mlp, scale_mlp, gate_mlp) = ( + None, None, None, None, None, None) + + # attention operation + x_skip = x + if self.use_adaLN: + x = modulate_fused(self.norm1(x), shift_msa, scale_msa) + else: + x = self.norm1(x) + + qkv = self.attn_qkv(x) + qkv = rearrange( + qkv, + 'b s (three h d) -> b s three h d', + three=3, + h=self.n_heads) + with torch.cuda.amp.autocast(enabled=False): + cos, sin = rotary_cos_sin + qkv = apply_rotary_pos_emb( + qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)) + qkv = rearrange(qkv, 'b s ... 
-> (b s) ...') + if seqlens is None: + cu_seqlens = torch.arange( + 0, (batch_size + 1) * seq_len, step=seq_len, + dtype=torch.int32, device=qkv.device) + else: + cu_seqlens = seqlens.cumsum(-1) + x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, seq_len, 0., causal=self.causal) + + x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size) + + x = bias_dropout_scale_fn(self.attn_out(x), + None, + gate_msa, + x_skip, + self.dropout) + + # mlp operation + x_skip = x + if self.use_adaLN: + x = modulate_fused(self.norm2(x), shift_mlp, scale_mlp) + else: + x = self.norm2(x) + return bias_dropout_scale_fn( + self.mlp(x),None, gate_mlp, x_skip, self.dropout) + + +class EmbeddingLayer(nn.Module): + def __init__(self, dim, vocab_dim): + super().__init__() + self.embedding = nn.Parameter( + torch.empty((vocab_dim, dim))) + torch.nn.init.kaiming_uniform_( + self.embedding, a=math.sqrt(5)) + + def forward(self, x): + return self.embedding[x] + + +class DDitFinalLayer(nn.Module): + def __init__( + self, hidden_size, out_channels, cond_dim, + use_adaLN=True): + super().__init__() + self.norm_final = LayerNorm(hidden_size) + self.linear = nn.Linear(hidden_size, out_channels) + self.linear.weight.data.zero_() + self.linear.bias.data.zero_() + + if use_adaLN: + self.adaLN_modulation = nn.Linear(cond_dim, + 2 * hidden_size, + bias=True) + self.adaLN_modulation.weight.data.zero_() + self.adaLN_modulation.bias.data.zero_() + else: + self.adaLN_modulation = None + + def forward(self, x, c): + if self.adaLN_modulation is not None: + shift, scale = self.adaLN_modulation( + c)[:, None].chunk(2, dim=2) + x = modulate_fused(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + return self.linear(self.norm_final(x)) + + +class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin): + def __init__(self, config, vocab_size): + super().__init__() + if type(config) == dict: + config = omegaconf.OmegaConf.create(config) + + self.config = config + self.vocab_size = vocab_size + self.causal = config.parameterization == 'ar' + + self.vocab_embed = EmbeddingLayer( + config.model.hidden_size, vocab_size) + + if self.causal: + self.sigma_map = None # no timestep embedding for AR + else: + self.sigma_map = TimestepEmbedder(config.model.cond_dim) + + if (config.training.guidance is not None or # Training for / using CFG + (hasattr(config, 'guidance') + and config.guidance is not None + and config.guidance.method == 'cfg')): + self.cond_map = LabelEmbedder( + config.data.num_classes + 1, # +1 for mask + config.model.cond_dim) + else: + self.cond_map = None + self.rotary_emb = Rotary( + config.model.hidden_size // config.model.n_heads) + + blocks = [] + use_adaLN = (config.parameterization != 'ar' or + self.cond_map is not None) + for _ in range(config.model.n_blocks): + blocks.append( + DDiTBlock( + config.model.hidden_size, + config.model.n_heads, + config.model.cond_dim, + dropout=config.model.dropout, + causal=self.causal, + use_adaLN=use_adaLN,)) + self.blocks = nn.ModuleList(blocks) + + self.output_layer = DDitFinalLayer( + config.model.hidden_size, + vocab_size, + config.model.cond_dim, + use_adaLN=use_adaLN) + self.scale_by_sigma = config.model.scale_by_sigma + + def _get_bias_dropout_scale(self): + if self.training: + return bias_dropout_add_scale_fused_train + else: + return bias_dropout_add_scale_fused_inference + + def forward(self, indices, sigma, + cond=None, + x_emb=None, + return_hidden_states=False): + if return_hidden_states: + hidden_states = [] + + if 
self.causal: +      c = None +    else: +      c = F.silu(self.sigma_map(sigma)) +    if cond is not None: +      if self.cond_map is None: +        raise ValueError("Conditioning variable provided, " +                         "but model was not initialized " +                         "with condition embedding layer.") +      else: +        if c is None:  # AR (self.causal is True) +          c = F.silu(self.cond_map(cond)) +        else: +          c = c + F.silu(self.cond_map(cond)) + +    if x_emb is None: +      x = self.vocab_embed(indices) +      if return_hidden_states: +        hidden_states.append(x) +      rotary_cos_sin = self.rotary_emb(x) + +      with torch.cuda.amp.autocast(dtype=torch.bfloat16): +        for i in range(len(self.blocks)): +          x = self.blocks[i](x, rotary_cos_sin, c, +                             seqlens=None) +          if return_hidden_states: +            hidden_states.append(x) +    else: +      x = x_emb + +    with torch.cuda.amp.autocast(dtype=torch.bfloat16): +      x = self.output_layer(x, c) + +    if return_hidden_states: +      return x, hidden_states +    return x + +class DITClassifier(nn.Module): +  def __init__(self, config, vocab_size): +    super().__init__() +    if type(config) == dict: +      config = omegaconf.OmegaConf.create(config) + +    self.config = config +    self.vocab_size = vocab_size +    self.causal = config.parameterization == 'ar' + +    self.vocab_embed = EmbeddingLayer( +      config.classifier_model.hidden_size, vocab_size) + +    if self.causal: +      self.sigma_map = None +    else: +      self.sigma_map = TimestepEmbedder(config.classifier_model.cond_dim) + +    self.rotary_emb = Rotary( +      config.classifier_model.hidden_size // config.classifier_model.n_heads) + +    blocks = [] +    use_adaLN = config.parameterization != 'ar' +    for _ in range(config.classifier_model.n_blocks): +      blocks.append( +        DDiTBlock(config.classifier_model.hidden_size, +                  config.classifier_model.n_heads, +                  config.classifier_model.cond_dim, +                  dropout=config.classifier_model.dropout, +                  causal=self.causal, +                  use_adaLN=use_adaLN)) +    self.blocks = nn.ModuleList(blocks) + +    self.scale_by_sigma = config.classifier_model.scale_by_sigma + +    self.pooling = getattr(config.classifier_model, 'pooling', 'mean') +    self.output_layer = nn.Linear( +      config.classifier_model.hidden_size, +      config.classifier_model.num_classes) + +  def _get_bias_dropout_scale(self): +    if self.training: +      return bias_dropout_add_scale_fused_train +    else: +      return bias_dropout_add_scale_fused_inference + +  def forward(self, indices_or_one_hots, sigma, x_emb=None, attention_mask=None): +    if x_emb is None: +      if indices_or_one_hots.ndim == 2:  # indices (B, L) +        x = self.vocab_embed(indices_or_one_hots) +      else:  # one-hots (B, L, V) +        x = F.linear(indices_or_one_hots.to(torch.float), +                     self.vocab_embed.embedding.T) + +      if self.causal: +        c = None +      else: +        c = F.silu(self.sigma_map(sigma)) + +      rotary_cos_sin = self.rotary_emb(x) + +      with torch.cuda.amp.autocast(dtype=torch.bfloat16): +        for i in range(len(self.blocks)): +          x = self.blocks[i](x, rotary_cos_sin, c, +                             seqlens=None) +    else: +      x = x_emb + +    with torch.cuda.amp.autocast(dtype=torch.bfloat16): +      if self.pooling == 'mean': +        x = x.mean(dim=1) +      elif self.pooling == 'max': +        x = x.max(dim=1).values +      elif self.pooling == 'cls': +        x = x[:, 0] +      elif self.pooling == 'last': +        x = x[:, -1] +      elif self.pooling == 'no_pooling':  # for ar_fudge +        pass +      elif self.pooling == 'attention_mean':  # for ar_pplm +        masked_x = x * attention_mask.unsqueeze(2) +        x = torch.sum(masked_x, dim=1) / (torch.sum(attention_mask, dim=1, keepdim=True) + 1e-15) +      else: +        raise NotImplementedError( +          f"`{self.pooling}` method not implemented.") +      x = self.output_layer(x) +    return x + +  def load_pretrained_encoder(self, encoder: nn.Module): +
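+    """Share the trunk of a pretrained DIT with this classifier: token +    embeddings, timestep embedder, rotary cache, and transformer blocks +    (the pooling and output head remain the classifier's own)."""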
self.vocab_embed = encoder.vocab_embed + self.sigma_map = encoder.sigma_map + self.rotary_emb = encoder.rotary_emb + self.blocks = encoder.blocks diff --git a/models/ema.py b/models/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..7ae05fdd97b430c82dfb39adedaf478ff0d655dd --- /dev/null +++ b/models/ema.py @@ -0,0 +1,101 @@ +import torch + + +class ExponentialMovingAverage: + """ + Maintains (exponential) moving average of a set of parameters. + """ + + def __init__(self, parameters, decay, use_num_updates=True): + """ + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the result of + `model.parameters()`. + decay: The exponential decay. + use_num_updates: Whether to use number of updates when computing + averages. + """ + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + self.decay = decay + self.num_updates = 0 if use_num_updates else None + self.shadow_params = [p.clone().detach() + for p in parameters if p.requires_grad] + self.collected_params = [] + + def move_shadow_params_to_device(self, device): + self.shadow_params = [i.to(device) for i in self.shadow_params] + + def update(self, parameters): + """ + Update currently maintained parameters. + + Call this every time the parameters are updated, such as the result of + the `optimizer.step()` call. + + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. + """ + decay = self.decay + if self.num_updates is not None: + self.num_updates += 1 + decay = min(decay, (1 + self.num_updates) / + (10 + self.num_updates)) + one_minus_decay = 1.0 - decay + with torch.no_grad(): + parameters = [p for p in parameters if p.requires_grad] + for s_param, param in zip(self.shadow_params, parameters): + s_param.sub_(one_minus_decay * (s_param - param)) + + def copy_to(self, parameters): + """ + Copy current parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. + """ + parameters = [p for p in parameters if p.requires_grad] + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + param.data.copy_(s_param.data) + + def store(self, parameters): + """ + Save the current parameters for restoring later. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + if len(self.collected_params) == 0: + raise RuntimeError( + 'No parameter values stored.' 
+ ' Use store() before restore().') + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) + + def state_dict(self): + return dict(decay=self.decay, + num_updates=self.num_updates, + shadow_params=self.shadow_params) + + def load_state_dict(self, state_dict): + self.decay = state_dict['decay'] + self.num_updates = state_dict['num_updates'] + self.shadow_params = state_dict['shadow_params'] diff --git a/models/hf/__init__.py b/models/hf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3992cf48b303bf251b20bca419861a6be48a554f --- /dev/null +++ b/models/hf/__init__.py @@ -0,0 +1,6 @@ +"""Hugging Face config and model. + +""" + +from .configuration_udlm import UDLMConfig +from .modeling_udlm import UDLM diff --git a/models/hf/configuration_udlm.py b/models/hf/configuration_udlm.py new file mode 100644 index 0000000000000000000000000000000000000000..e8efa4e1052e00abfffa30f65184647460330faa --- /dev/null +++ b/models/hf/configuration_udlm.py @@ -0,0 +1,35 @@ +"""UDLM config for Hugging Face. + +""" + +import transformers + + +class UDLMConfig(transformers.PretrainedConfig): + """Hugging Face configuration class for UDLM.""" + model_type = "udlm" + + def __init__( + self, + vocab_size: int = 30522, # `bert-base-uncased` vocab size + model_length: int = 128, + hidden_dim: int = 768, + cond_dim: int = 128, + n_blocks: int = 12, + n_heads: int = 12, + dropout: float = 0.1, + time_conditioning: bool = True, + cfg: bool = False, # Whether model is used for Classifier-Free Guidance (CFG) + cfg_num_classes: int = -1, # Number of classes for CFG (dummy value of -1 for no CFG) + ** kwargs): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.model_length = model_length + self.hidden_dim = hidden_dim + self.cond_dim = cond_dim + self.n_blocks = n_blocks + self.n_heads = n_heads + self.dropout = dropout + self.time_conditioning = time_conditioning + self.cfg = cfg + self.cfg_num_classes = cfg_num_classes diff --git a/models/hf/modeling_udlm.py b/models/hf/modeling_udlm.py new file mode 100644 index 0000000000000000000000000000000000000000..4c18609df22b9f6ac012df01488465819792ae8a --- /dev/null +++ b/models/hf/modeling_udlm.py @@ -0,0 +1,486 @@ +"""UDLM model for Hugging Face. 
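+ +This module mirrors `models/dit.py`, but attention uses the plain PyTorch fallback `regular_attention_multi_headed` instead of the flash-attn varlen kernel; rotary embeddings still go through `flash_attn.layers.rotary`.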
+ +""" +import math +import typing + +import einops +import flash_attn +import flash_attn.layers.rotary +import torch +import torch.nn as nn +import torch.nn.functional as F +import transformers +from transformers import modeling_outputs + +from .configuration_udlm import UDLMConfig + +# Flags required to enable jit fusion kernels +torch._C._jit_set_profiling_mode(False) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_override_can_fuse_on_cpu(True) +torch._C._jit_override_can_fuse_on_gpu(True) + + +def bias_dropout_add_scale( + x: torch.Tensor, + bias: typing.Optional[torch.Tensor], + scale: torch.Tensor, + residual: typing.Optional[torch.Tensor], + prob: float, + training: bool) -> torch.Tensor: + if bias is not None: + out = scale * F.dropout(x + bias, p=prob, training=training) + else: + out = scale * F.dropout(x, p=prob, training=training) + + if residual is not None: + out = residual + out + return out + + +def get_bias_dropout_add_scale(training): + def _bias_dropout_add(x, bias, scale, residual, prob): + return bias_dropout_add_scale( + x, bias, scale, residual, prob, training) + + return _bias_dropout_add + + +# function overload +def modulate(x: torch.Tensor, + shift: torch.Tensor, + scale: torch.Tensor) -> torch.Tensor: + return x * (1 + scale) + shift + + +@torch.jit.script +def bias_dropout_add_scale_fused_train( + x: torch.Tensor, + bias: typing.Optional[torch.Tensor], + scale: torch.Tensor, + residual: typing.Optional[torch.Tensor], + prob: float) -> torch.Tensor: + return bias_dropout_add_scale( + x, bias, scale, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_scale_fused_inference( + x: torch.Tensor, + bias: typing.Optional[torch.Tensor], + scale: torch.Tensor, + residual: typing.Optional[torch.Tensor], + prob: float) -> torch.Tensor: + return bias_dropout_add_scale( + x, bias, scale, residual, prob, False) + + +@torch.jit.script +def modulate_fused(x: torch.Tensor, + shift: torch.Tensor, + scale: torch.Tensor) -> torch.Tensor: + return modulate(x, shift, scale) + + +class Rotary(torch.nn.Module): + def __init__(self, dim, base=10_000): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + + def forward(self, x, seq_dim=1): + seq_len = x.shape[seq_dim] + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone()) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + # dims are: batch, seq_len, qkv, head, dim + self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1) + self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1) + # This makes the transformation on v an identity. + self.cos_cached[:,:,2,:,:].fill_(1.) + self.sin_cached[:,:,2,:,:].fill_(0.) 
+ + return self.cos_cached, self.sin_cached + + +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(qkv, cos, sin): + cos = cos[0,:,0,0,:cos.shape[-1]//2] + sin = sin[0,:,0,0,:sin.shape[-1]//2] + return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin) + + +# function overload +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Layers # +################################################################################# +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.weight = nn.Parameter(torch.ones([dim])) + self.dim = dim + def forward(self, x): + with torch.cuda.amp.autocast(enabled=False): + x = F.layer_norm(x.float(), [self.dim]) + return x * self.weight[None,None,:] + + +def residual_linear(x, W, x_skip, residual_scale): + """x_skip + residual_scale * W @ x""" + dim_out, dim_in = W.shape[0], W.shape[1] + return torch.addmm( + x_skip.view(-1, dim_out), + x.view(-1, dim_in), + W.T, + alpha=residual_scale).view(*x.shape[:-1], dim_out) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True)) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + - math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, + torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """Embeds class labels into vector representations.""" + def __init__(self, num_classes, cond_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, + cond_size) + self.num_classes = num_classes + + def forward(self, labels): + embeddings = self.embedding_table(labels) + return embeddings + + +################################################################################# +# Core Model # +################################################################################# + +def regular_attention_multi_headed(qkv): + # Assuming qkv is a tensor with shape [batch, seq_len, 3, num_heads, head_dim] + # where the 3 represents Q, K, V packed in that order + batch_size, seq_len, _, num_heads, head_dim = qkv.shape + # Separate Q, K, V from the packed qkv tensor + # [batch_size, seq_len, num_heads, head_dim] + q = qkv[:, :, 0, :, :] + k = qkv[:, :, 1, :, :] + v = qkv[:, :, 2, :, :] + + # Transpose and reshape Q and K for batched matrix multiplication: + # [batch_size, num_heads, seq_len, head_dim] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Compute scaled dot-product attention + # [batch_size, num_heads, seq_len, seq_len] + attention_scores = torch.matmul( + q, k.transpose(-2, -1)) / math.sqrt(head_dim) + + # Apply softmax to calculate the attention weights + attention_probs = F.softmax(attention_scores, dim=-1) + + # [batch_size, num_heads, seq_len, head_dim] + attention_output = torch.matmul(attention_probs, v) + + # [batch_size, seq_len, num_heads, head_dim] + attention_output = attention_output.transpose(1, 2) + return einops.rearrange(attention_output, + 'b s h d -> b s (h d)') + + +class DDiTBlock(nn.Module): + def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, + dropout=0.1, use_flash_attn=True): + super().__init__() + self.n_heads = n_heads + self.use_flash_attn = use_flash_attn + + self.norm1 = LayerNorm(dim) + self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False) + self.attn_out = nn.Linear(dim, dim, bias=False) + self.dropout1 = nn.Dropout(dropout) + + self.norm2 = LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, mlp_ratio * dim, bias=True), + nn.GELU(approximate='tanh'), + nn.Linear(mlp_ratio * dim, dim, bias=True)) + self.dropout2 = nn.Dropout(dropout) + self.dropout = dropout + + self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True) + self.adaLN_modulation.weight.data.zero_() + self.adaLN_modulation.bias.data.zero_() + + + def _get_bias_dropout_scale(self): + if self.training: + return bias_dropout_add_scale_fused_train + else: + return bias_dropout_add_scale_fused_inference + + + def forward(self, x, rotary_cos_sin, c, seqlens=None): + batch_size, seq_len = x.shape[0], x.shape[1] + + bias_dropout_scale_fn = self._get_bias_dropout_scale() + + (shift_msa, scale_msa, gate_msa, shift_mlp, + scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2) + + # attention 
operation + x_skip = x + x = modulate_fused(self.norm1(x), shift_msa, scale_msa) + + qkv = self.attn_qkv(x) + qkv = einops.rearrange( + qkv, + 'b s (three h d) -> b s three h d', + three=3, + h=self.n_heads) + with torch.cuda.amp.autocast(enabled=False): + cos, sin = rotary_cos_sin + qkv = apply_rotary_pos_emb( + qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)) + if seqlens is None: + cu_seqlens = torch.arange( + 0, (batch_size + 1) * seq_len, step=seq_len, + dtype=torch.int32, device=qkv.device) + else: + cu_seqlens = seqlens.cumsum(-1) + x = regular_attention_multi_headed(qkv) + + x = bias_dropout_scale_fn(self.attn_out(x), + None, + gate_msa, + x_skip, + self.dropout) + + # mlp operation + x = bias_dropout_scale_fn( + self.mlp(modulate_fused( + self.norm2(x), shift_mlp, scale_mlp)), + None, gate_mlp, x, self.dropout) + return x + + + +class EmbeddingLayer(nn.Module): + def __init__(self, dim, vocab_dim): + super().__init__() + self.embedding = nn.Parameter(torch.empty((vocab_dim, dim))) + torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5)) + + def forward(self, x): + return self.embedding[x] + + +class DDitFinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, cond_dim): + super().__init__() + self.norm_final = LayerNorm(hidden_size) + self.linear = nn.Linear(hidden_size, out_channels) + self.linear.weight.data.zero_() + self.linear.bias.data.zero_() + + self.adaLN_modulation = nn.Linear(cond_dim, + 2 * hidden_size, + bias=True) + self.adaLN_modulation.weight.data.zero_() + self.adaLN_modulation.bias.data.zero_() + + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2) + x = modulate_fused(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DITBackbone(nn.Module): + def __init__( + self, + config: UDLMConfig): + super().__init__() + + self.config = config + self.vocab_size = config.vocab_size + + self.vocab_embed = EmbeddingLayer( + config.hidden_dim, + config.vocab_size) + self.sigma_map = TimestepEmbedder( + config.cond_dim) + if config.cfg: + self.cond_map = LabelEmbedder( + config.cfg_num_classes + 1, # +1 for mask + config.cond_dim) + else: + self.cond_map = None + self.rotary_emb = Rotary( + config.hidden_dim // config.n_heads) + + blocks = [] + for _ in range(config.n_blocks): + blocks.append(DDiTBlock(config.hidden_dim, + config.n_heads, + config.cond_dim, + dropout=config.dropout)) + self.blocks = nn.ModuleList(blocks) + + self.output_layer = DDitFinalLayer( + config.hidden_dim, + config.vocab_size, + config.cond_dim) + self.precision = torch.float32 + + def _get_bias_dropout_scale(self): + if self.training: + return bias_dropout_add_scale_fused_train + else: + return bias_dropout_add_scale_fused_inference + + def forward( + self, + indices, + sigma, + cond=None, + x_emb=None, + output_hidden_states=False): + if not self.config.time_conditioning: + sigma = torch.zeros_like(sigma) + all_hidden_states = [] + + c = F.silu(self.sigma_map(sigma)) + if cond is not None: + if self.cond_map is None: + raise ValueError("Conditioning variable provided, " + "but Model was not initialized " + "with condition embedding layer.") + else: + c = c + F.silu(self.cond_map(cond)) + + if x_emb is None: + x = self.vocab_embed(indices) + if output_hidden_states: + all_hidden_states.append(x) + + rotary_cos_sin = self.rotary_emb(x) + + with torch.cuda.amp.autocast(dtype=self.precision): + for i in range(len(self.blocks)): + x = self.blocks[i](x, rotary_cos_sin, c, + seqlens=None) + if output_hidden_states: + 
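# Per-block hidden states are collected here so guidance methods that probe + # intermediate representations (e.g., NOS-style classifiers) can reuse them. +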
all_hidden_states.append(x) + else: + x = x_emb + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + logits = self.output_layer(x, c) + return logits, all_hidden_states + +class UDLM(transformers.PreTrainedModel): + """HF-compatible model.""" + config_class = UDLMConfig + base_model_prefix = "udlm" + + def __init__( + self, + config: UDLMConfig): + super().__init__(config) + self.backbone = DITBackbone(config) + + def forward( + self, + input_ids: torch.LongTensor = None, + timesteps: torch.FloatTensor = None, + cond: torch.LongTensor = None, + output_hidden_states: typing.Optional[bool] = None, + return_dict: typing.Optional[bool] = None, + **kwargs, + ) -> typing.Union[ + torch.Tensor, typing.Tuple, + modeling_outputs.MaskedLMOutput]: + """HF-compatible forward method.""" + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict \ + if return_dict is not None \ + else self.config.use_return_dict + + logits, all_hidden_states = self.backbone( + indices=input_ids, + sigma=timesteps, + cond=cond, + output_hidden_states=output_hidden_states, + **kwargs, + ) + if return_dict: + return modeling_outputs.MaskedLMOutput( + logits=logits, + hidden_states=all_hidden_states if output_hidden_states else None, + loss=None + ) + elif output_hidden_states: + return logits, all_hidden_states + else: + return logits diff --git a/models/hf/push_to_hf_hub.ipynb b/models/hf/push_to_hf_hub.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f0d495b3026d12a9f77e008543b8e9f5522a2363 --- /dev/null +++ b/models/hf/push_to_hf_hub.ipynb @@ -0,0 +1,244 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c8b6fdc7-e525-44de-bc51-0eb838d1d1af", + "metadata": {}, + "source": [ + "### Push UDLM to Hub" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ce3f3e3-9bc0-45cc-8e07-e507378df58a", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import huggingface_hub\n", + "import torch\n", + "import transformers\n", + "\n", + "from models.hf import UDLMConfig\n", + "from models.hf import UDLM\n", + "from models.ema import ExponentialMovingAverage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "684e2f51-a993-405a-8447-dc7e1445465c", + "metadata": {}, + "outputs": [], + "source": [ + "if os.path.exists(os.path.join(os.environ['HF_HOME'], 'token')):\n", + " with open(os.path.join(os.environ['HF_HOME'], 'token'), 'r') as f:\n", + " token = f.read().strip()\n", + "else:\n", + " token = None\n", + "huggingface_hub.login(token=token)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65dc9b6e-5686-438f-935f-b4928a4e2c84", + "metadata": {}, + "outputs": [], + "source": [ + "UDLMConfig.register_for_auto_class()\n", + "UDLM.register_for_auto_class('AutoModelForMaskedLM')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35732a5f-676b-4a7b-b21a-470db0e498d9", + "metadata": {}, + "outputs": [], + "source": [ + "device = 'cuda'\n", + "# 'bert-base-uncased' for LM1B\n", + "# 'yairschiff/qm9-tokenizer' for QM9\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased', trust_remote_code=True)\n", + "# tokenizer = transformers.AutoTokenizer.from_pretrained('yairschiff/qm9-tokenizer', trust_remote_code=True)\n", + "\n", + "# 'kuleshov-group/udlm-lm1b' for LM1B\n", + "# 'kuleshov-group/udlm-qm9' for QM9\n", + "name_or_path = 'kuleshov-group/udlm-lm1b'\n", + "# 
name_or_path = 'kuleshov-group/udlm-qm9'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f246fa0c-077d-432c-98fa-f64da7c17cea", + "metadata": {}, + "outputs": [], + "source": [ + "config = UDLMConfig(\n", + " vocab_size=tokenizer.vocab_size,\n", + " model_length=128,\n", + " hidden_dim=768,\n", + " cond_dim=128,\n", + " n_blocks=12, \n", + " n_heads=12,\n", + " dropout=0.1,\n", + " time_conditioning=True,\n", + " cfg=False,\n", + " cfg_num_classes=-1,\n", + " return_dict=False\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9bcef091-fad2-4e84-a1bc-03c629572d24", + "metadata": {}, + "outputs": [], + "source": [ + "model = UDLM(config)\n", + "ema = ExponentialMovingAverage(\n", + " model.backbone.parameters(),\n", + " decay=0.0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a5f8cb6-0021-46bb-a9bc-3db36a28f111", + "metadata": {}, + "outputs": [], + "source": [ + "model.config._name_or_path = name_or_path\n", + "model.config.auto_map = {\n", + " 'AutoConfig': f'{name_or_path}--configuration_udlm.UDLMConfig',\n", + " 'AutoModelForMaskedLM': f'{name_or_path}--modeling_udlm.UDLM',\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0062a090-b72b-4b7c-9da9-949e5d03e8da", + "metadata": {}, + "outputs": [], + "source": [ + "ckpt_path = ''\n", + "ckpt = torch.load(ckpt_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "974d3128-5432-4c4f-ae48-687b25f2b29b", + "metadata": {}, + "outputs": [], + "source": [ + "ema.load_state_dict(ckpt['ema'])\n", + "ema.copy_to(model.backbone.parameters())\n", + "model = model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f111403-3bda-4f5b-b98e-ccb4c7cb93ad", + "metadata": {}, + "outputs": [], + "source": [ + "# Confirm EMA params loaded\n", + "for c, m in zip(ema.shadow_params, ckpt['ema']['shadow_params']):\n", + " if not torch.allclose(c.to(device), m.to(device)):\n", + " print('Issue with EMA!')\n", + "\n", + "for c, m in zip(ema.shadow_params, model.parameters()):\n", + " if not torch.allclose(c.to(device), m.to(device)):\n", + " print('Issue with EMA!')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d172396-b837-4384-95ec-7e6704c89b79", + "metadata": {}, + "outputs": [], + "source": [ + "model.push_to_hub(name_or_path, private=False)" + ] + }, + { + "cell_type": "markdown", + "id": "1ccf941f-e01a-4233-be07-2cb1dd3004ee", + "metadata": {}, + "source": [ + "### Test Model from Hub" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53091198-af84-40fc-84a2-802779e8e75e", + "metadata": {}, + "outputs": [], + "source": [ + "model_test = transformers.AutoModelForMaskedLM.from_pretrained(name_or_path, trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14fbb458-444e-4127-b821-d46a040d97a0", + "metadata": {}, + "outputs": [], + "source": [ + "input_ids = torch.randint(10, size=(2, 10)).to(device)\n", + "model_test = model_test.to(device)\n", + "model_test.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee5d3580-5fbb-44a4-ba92-5c92bb6a208b", + "metadata": {}, + "outputs": [], + "source": [ + "print(model_test(input_ids, torch.zeros(2,).to(device)).shape)\n", + "print(model_test(input_ids, torch.zeros(2,).to(device)).mean())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be448f86-7d93-4251-9985-466f02327629", + "metadata": {}, + "outputs": [],
+ "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.20" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/models/unet.py b/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..af4d851df5b30ab31a13b60f034a79678e5c2801 --- /dev/null +++ b/models/unet.py @@ -0,0 +1,567 @@ +import torch +import torch.nn as nn +import math +import torch.nn.functional as F +import numpy as np +import omegaconf + +import transformers +from einops import rearrange +from .dit import LabelEmbedder, EmbeddingLayer + + +# From https://github.com/yang-song/score_sde_pytorch/ which is from +# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py +def transformer_timestep_embedding(timesteps, embedding_dim, max_positions=10000): + assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 + half_dim = embedding_dim // 2 + # magic number 10000 is from transformers + emb = math.log(max_positions) / (half_dim - 1) + # emb = math.log(2.) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) + # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] + # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = F.pad(emb, (0, 1), mode='constant') + assert emb.shape == (timesteps.shape[0], embedding_dim) + return emb + + +# Code modified from https://github.com/yang-song/score_sde_pytorch +def variance_scaling(scale, mode, distribution, + in_axis=1, out_axis=0, + dtype=torch.float32, + device='cpu'): + """Ported from JAX. """ + + def _compute_fans(shape, in_axis=1, out_axis=0): + receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] + fan_in = shape[in_axis] * receptive_field_size + fan_out = shape[out_axis] * receptive_field_size + return fan_in, fan_out + + def init(shape, dtype=dtype, device=device): + fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) + if mode == "fan_in": + denominator = fan_in + elif mode == "fan_out": + denominator = fan_out + elif mode == "fan_avg": + denominator = (fan_in + fan_out) / 2 + else: + raise ValueError( + "invalid mode for variance scaling initializer: {}".format(mode)) + variance = scale / denominator + if distribution == "normal": + return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) + elif distribution == "uniform": + return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) 
* np.sqrt(3 * variance) + else: + raise ValueError("invalid distribution for variance scaling initializer") + + return init + + +def default_init(scale=1.): + """The same initialization used in DDPM.""" + scale = 1e-10 if scale == 0 else scale + return variance_scaling(scale, 'fan_avg', 'uniform') + + +class NiN(nn.Module): + def __init__(self, in_ch, out_ch, init_scale=0.1): + super().__init__() + self.W = nn.Parameter(default_init(scale=init_scale)((in_ch, out_ch)), requires_grad=True) + self.b = nn.Parameter(torch.zeros(out_ch), requires_grad=True) + + def forward(self, x, # ["batch", "in_ch", "H", "W"] + ): + + x = x.permute(0, 2, 3, 1) + # x (batch, H, W, in_ch) + y = torch.einsum('bhwi,ik->bhwk', x, self.W) + self.b + # y (batch, H, W, out_ch) + return y.permute(0, 3, 1, 2) + +class AttnBlock(nn.Module): + """Channel-wise self-attention block.""" + def __init__(self, channels, skip_rescale=True): + super().__init__() + self.skip_rescale = skip_rescale + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels//4, 32), + num_channels=channels, eps=1e-6) + self.NIN_0 = NiN(channels, channels) + self.NIN_1 = NiN(channels, channels) + self.NIN_2 = NiN(channels, channels) + self.NIN_3 = NiN(channels, channels, init_scale=0.) + + def forward(self, x, # ["batch", "channels", "H", "W"] + ): + + B, C, H, W = x.shape + h = self.GroupNorm_0(x) + q = self.NIN_0(h) + k = self.NIN_1(h) + v = self.NIN_2(h) + + w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) + w = torch.reshape(w, (B, H, W, H * W)) + w = F.softmax(w, dim=-1) + w = torch.reshape(w, (B, H, W, H, W)) + h = torch.einsum('bhwij,bcij->bchw', w, v) + h = self.NIN_3(h) + + if self.skip_rescale: + return (x + h) / np.sqrt(2.) + else: + return x + h + + +class ResBlock(nn.Module): + def __init__(self, in_ch, out_ch, temb_dim=None, dropout=0.1, skip_rescale=True): + super().__init__() + + self.in_ch = in_ch + self.out_ch = out_ch + + self.skip_rescale = skip_rescale + + self.act = nn.functional.silu + self.groupnorm0 = nn.GroupNorm( + num_groups=min(in_ch // 4, 32), + num_channels=in_ch, eps=1e-6 + ) + self.conv0 = nn.Conv2d( + in_ch, out_ch, kernel_size=3, padding=1 + ) + + if temb_dim is not None: + self.dense0 = nn.Linear(temb_dim, out_ch) + nn.init.zeros_(self.dense0.bias) + + + self.groupnorm1 = nn.GroupNorm( + num_groups=min(out_ch // 4, 32), + num_channels=out_ch, eps=1e-6 + ) + self.dropout0 = nn.Dropout(dropout) + + self.conv1 = nn.Conv2d( + out_ch, out_ch, kernel_size=3, padding=1 + ) + if out_ch != in_ch: + self.nin = NiN(in_ch, out_ch) + + def forward(self, x, # ["batch", "in_ch", "H", "W"] + temb=None, # ["batch", "temb_dim"] + ): + + assert x.shape[1] == self.in_ch + + h = self.groupnorm0(x) + h = self.act(h) + h = self.conv0(h) + + if temb is not None: + h += self.dense0(self.act(temb))[:, :, None, None] + + h = self.groupnorm1(h) + h = self.act(h) + h = self.dropout0(h) + h = self.conv1(h) + if h.shape[1] != self.in_ch: + x = self.nin(x) + + assert x.shape == h.shape + + if self.skip_rescale: + return (x + h) / np.sqrt(2.) 
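+ # Dividing the residual sum by sqrt(2) keeps activation variance roughly constant across blocks (NCSN++-style skip rescaling).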
+ else: + return x + h + +class Downsample(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = nn.Conv2d(channels, channels, kernel_size=3, + stride=2, padding=0) + + def forward(self, x, # ["batch", "ch", "inH", "inW"] + ): + B, C, H, W = x.shape + x = nn.functional.pad(x, (0, 1, 0, 1)) + x= self.conv(x) + + assert x.shape == (B, C, H // 2, W // 2) + return x + +class Upsample(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + + def forward(self, x, # ["batch", "ch", "inH", "inW"] + ): + B, C, H, W = x.shape + h = F.interpolate(x, (H*2, W*2), mode='nearest') + h = self.conv(h) + + assert h.shape == (B, C, H*2, W*2) + return h + + +class UNet(nn.Module): + def __init__(self, config, vocab_size=None): + super().__init__() + if type(config) == dict: + config = omegaconf.OmegaConf.create(config) + self.ch = config.model.ch + self.num_res_blocks = config.model.num_res_blocks + self.num_scales = config.model.num_scales + self.ch_mult = config.model.ch_mult + assert self.num_scales == len(self.ch_mult) + self.input_channels = config.model.input_channels + self.output_channels = 2 * config.model.input_channels + self.scale_count_to_put_attn = config.model.scale_count_to_put_attn + self.data_min_max = [0, vocab_size] # config.model.data_min_max # tuple of min and max value of input so it can be rescaled to [-1, 1] + self.dropout = config.model.dropout + self.skip_rescale = config.model.skip_rescale + self.time_conditioning = config.model.time_conditioning # Whether to add in time embeddings + self.time_scale_factor = config.model.time_scale_factor # scale to make the range of times be 0 to 1000 + self.time_embed_dim = config.model.time_embed_dim + self.vocab_size = vocab_size + + self.size = config.model.size + self.length = config.model.length + + # truncated logistic + self.fix_logistic = config.model.fix_logistic + + self.act = nn.functional.silu + + if self.time_conditioning: + self.temb_modules = [] + self.temb_modules.append(nn.Linear(self.time_embed_dim, self.time_embed_dim*4)) + nn.init.zeros_(self.temb_modules[-1].bias) + self.temb_modules.append(nn.Linear(self.time_embed_dim*4, self.time_embed_dim*4)) + nn.init.zeros_(self.temb_modules[-1].bias) + self.temb_modules = nn.ModuleList(self.temb_modules) + + self.expanded_time_dim = 4 * self.time_embed_dim if self.time_conditioning else None + + self.input_conv = nn.Conv2d( + in_channels=self.input_channels, out_channels=self.ch, + kernel_size=3, padding=1 + ) + + h_cs = [self.ch] + in_ch = self.ch + + # Downsampling + self.downsampling_modules = [] + + for scale_count in range(self.num_scales): + for res_count in range(self.num_res_blocks): + out_ch = self.ch * self.ch_mult[scale_count] + self.downsampling_modules.append( + ResBlock(in_ch, out_ch, temb_dim=self.expanded_time_dim, + dropout=self.dropout, skip_rescale=self.skip_rescale) + ) + in_ch = out_ch + h_cs.append(in_ch) + if scale_count == self.scale_count_to_put_attn: + self.downsampling_modules.append( + AttnBlock(in_ch, skip_rescale=self.skip_rescale) + ) + + if scale_count != self.num_scales - 1: + self.downsampling_modules.append(Downsample(in_ch)) + h_cs.append(in_ch) + + self.downsampling_modules = nn.ModuleList(self.downsampling_modules) + + # Middle + self.middle_modules = [] + + self.middle_modules.append( + ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim, + dropout=self.dropout, skip_rescale=self.skip_rescale) + ) + self.middle_modules.append( + 
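# Bottleneck at the lowest resolution: ResBlock, spatial self-attention, ResBlock. +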
AttnBlock(in_ch, skip_rescale=self.skip_rescale) + ) + self.middle_modules.append( + ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim, + dropout=self.dropout, skip_rescale=self.skip_rescale) + ) + self.middle_modules = nn.ModuleList(self.middle_modules) + + # Upsampling + self.upsampling_modules = [] + + for scale_count in reversed(range(self.num_scales)): + for res_count in range(self.num_res_blocks+1): + out_ch = self.ch * self.ch_mult[scale_count] + self.upsampling_modules.append( + ResBlock(in_ch + h_cs.pop(), + out_ch, + temb_dim=self.expanded_time_dim, + dropout=self.dropout, + skip_rescale=self.skip_rescale + ) + ) + in_ch = out_ch + + if scale_count == self.scale_count_to_put_attn: + self.upsampling_modules.append( + AttnBlock(in_ch, skip_rescale=self.skip_rescale) + ) + if scale_count != 0: + self.upsampling_modules.append(Upsample(in_ch)) + + self.upsampling_modules = nn.ModuleList(self.upsampling_modules) + + assert len(h_cs) == 0 + + # output + self.output_modules = [] + + self.output_modules.append( + nn.GroupNorm(min(in_ch//4, 32), in_ch, eps=1e-6) + ) + + self.output_modules.append( + nn.Conv2d(in_ch, self.output_channels, kernel_size=3, padding=1) + ) + self.output_modules = nn.ModuleList(self.output_modules) + + if config.training.guidance: + self.cond_map = LabelEmbedder( + config.data.num_classes + 1, # +1 for mask + self.time_embed_dim*4) + else: + self.cond_map = None + + def _center_data(self, x): + out = (x - self.data_min_max[0]) / (self.data_min_max[1] - self.data_min_max[0]) # [0, 1] + return 2 * out - 1 # to put it in [-1, 1] + + def _time_embedding(self, timesteps): + if self.time_conditioning: + temb = transformer_timestep_embedding( + timesteps * self.time_scale_factor, self.time_embed_dim + ) + temb = self.temb_modules[0](temb) + temb = self.temb_modules[1](self.act(temb)) + else: + temb = None + + return temb + + def _do_input_conv(self, h): + h = self.input_conv(h) + hs = [h] + return h, hs + + def _do_downsampling(self, h, hs, temb): + m_idx = 0 + for scale_count in range(self.num_scales): + for res_count in range(self.num_res_blocks): + h = self.downsampling_modules[m_idx](h, temb) + m_idx += 1 + if scale_count == self.scale_count_to_put_attn: + h = self.downsampling_modules[m_idx](h) + m_idx += 1 + hs.append(h) + + if scale_count != self.num_scales - 1: + h = self.downsampling_modules[m_idx](h) + hs.append(h) + m_idx += 1 + + assert m_idx == len(self.downsampling_modules) + + return h, hs + + def _do_middle(self, h, temb): + m_idx = 0 + h = self.middle_modules[m_idx](h, temb) + m_idx += 1 + h = self.middle_modules[m_idx](h) + m_idx += 1 + h = self.middle_modules[m_idx](h, temb) + m_idx += 1 + + assert m_idx == len(self.middle_modules) + + return h + + def _do_upsampling(self, h, hs, temb): + m_idx = 0 + for scale_count in reversed(range(self.num_scales)): + for res_count in range(self.num_res_blocks+1): + h = self.upsampling_modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) + m_idx += 1 + + if scale_count == self.scale_count_to_put_attn: + h = self.upsampling_modules[m_idx](h) + m_idx += 1 + + if scale_count != 0: + h = self.upsampling_modules[m_idx](h) + m_idx += 1 + + assert len(hs) == 0 + assert m_idx == len(self.upsampling_modules) + + return h + + def _do_output(self, h): + + h = self.output_modules[0](h) + h = self.act(h) + h = self.output_modules[1](h) + + return h + + def _logistic_output_res(self, + h, # ["B", "twoC", "H", "W"] + centered_x_in, # ["B", "C", "H", "W"] + ): + B, twoC, H, W = h.shape + C = twoC//2 + h[:, 0:C, :, :] = 
torch.tanh(centered_x_in + h[:, 0:C, :, :]) + return h + + def _log_minus_exp(self, a, b, eps=1e-6): + """ + Compute log (exp(a) - exp(b)) for (b < a). + When b << a, torch.log1p(-exp(b - a)) is numerically stable; eps guards the log. + """ + return a + torch.log1p(-torch.exp(b - a) + eps) + + def _truncated_logistic_output(self, h): + # h (B, 2*C, H, W): per-channel mean (first C channels, in [-1, 1]) and log-scale. + # Discretizes a truncated logistic distribution into S bins per pixel channel. + B, twoC, H, W = h.shape + C = twoC // 2 + S = self.vocab_size + mu = h[:, 0:C, :, :].unsqueeze(-1) + log_scale = h[:, C:, :, :].unsqueeze(-1) + inv_scale = torch.exp(-(log_scale - 2)) + bin_width = 2. / S + bin_centers = torch.linspace( + start=-1. + bin_width / 2, + end=1. - bin_width / 2, + steps=S, + device=h.device).view(1, 1, 1, 1, S) + sig_in_left = (bin_centers - bin_width / 2 - mu) * inv_scale + bin_left_logcdf = F.logsigmoid(sig_in_left) + sig_in_right = (bin_centers + bin_width / 2 - mu) * inv_scale + bin_right_logcdf = F.logsigmoid(sig_in_right) + # Two algebraically equivalent forms of log P(bin); the min is numerically safer. + logits_1 = self._log_minus_exp(bin_right_logcdf, bin_left_logcdf) + logits_2 = self._log_minus_exp( + -sig_in_left + bin_left_logcdf, -sig_in_right + bin_right_logcdf) + if self.fix_logistic: + logits = torch.min(logits_1, logits_2) + else: + logits = logits_1 + return logits.reshape(B, -1, S) # (B, D, S), D flattened as (c h w) + + def forward(self, x, timesteps, cond=None): + img_size = int(np.sqrt(self.length // 3)) + h = rearrange(x, "b (c h w) -> b c h w", h=img_size, w=img_size, c=3) + h = self._center_data(h) + centered_x_in = h + + temb = self._time_embedding(timesteps) + if cond is not None: + if self.cond_map is None: + raise ValueError("Conditioning variable provided, " + "but Model was not initialized " + "with condition embedding layer.") + else: + assert cond.shape == (x.shape[0],) + temb = temb + self.cond_map(cond) + + h, hs = self._do_input_conv(h) + + h, hs = self._do_downsampling(h, hs, temb) + + h = self._do_middle(h, temb) + + h = self._do_upsampling(h, hs, temb) + + h = self._do_output(h) + + # h (B, 2*C, H, W) + h = self._logistic_output_res(h, centered_x_in) + h = self._truncated_logistic_output(h) # (B, D, S) + + return h + + +class UNetConfig(transformers.PretrainedConfig): + """Hugging Face configuration class for MDLM.""" + model_type = "unet" + + def __init__( + self, + ch: int = 128, + num_res_blocks: int = 2, + num_scales: int = 4, + ch_mult: list = [1, 2, 2, 2], + input_channels: int = 3, + output_channels: int = 3, + scale_count_to_put_attn: int = 1, + data_min_max: list = [0, 255], # tuple of min and max value of input so it can be rescaled to [-1, 1] + dropout: float = 0.1, + skip_rescale: bool = True, + time_conditioning: bool = True, # Whether to add in time embeddings + time_scale_factor: float = 1000, # scale to make the range of times be 0 to 1000 + time_embed_dim: int = 128, + fix_logistic: bool = False, + vocab_size: int = 256, + size: int = 1024, + guidance_classifier_free: bool = False, + guidance_num_classes: int = -1, + cond_dim: int = -1, + length: int = 3072, # 3x32x32 + **kwargs): + + super().__init__(**kwargs) + self.ch = ch + self.num_res_blocks = num_res_blocks + self.num_scales = num_scales + self.ch_mult = ch_mult + self.input_channels = input_channels + self.output_channels = vocab_size + self.scale_count_to_put_attn = scale_count_to_put_attn + self.data_min_max = data_min_max # tuple of min and max value of input so it can be rescaled to [-1, 1] + self.dropout = dropout + self.skip_rescale = skip_rescale + self.time_conditioning = time_conditioning # Whether to add in time embeddings + self.time_scale_factor = time_scale_factor # scale to make the range of times be 0 to 1000 + self.time_embed_dim = time_embed_dim + self.fix_logistic = fix_logistic + + self.vocab_size = vocab_size + self.size = size + self.guidance_classifier_free = guidance_classifier_free + self.guidance_num_classes = guidance_num_classes + self.cond_dim = cond_dim + self.length = length diff --git a/scripts/eval_amazon_polarity_gen_ppl.sh b/scripts/eval_amazon_polarity_gen_ppl.sh new file mode 100644 index 0000000000000000000000000000000000000000..f1d3dfd7a37e02907b68c9778bfff467f4747cd6 --- /dev/null +++ b/scripts/eval_amazon_polarity_gen_ppl.sh @@ -0,0 +1,94 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the users login environment +#SBATCH --mem=32000 # server memory requested (per node) +#SBATCH -t 96:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +MODEL=<ar|mdlm|udlm> +sbatch \ + --export=ALL,MODEL=${MODEL} \ +
--job-name=eval_amazon_polarity_gen_ppl_${MODEL} \ + eval_amazon_polarity_gen_ppl.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh || exit +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (choices: ar, mdlm, udlm) +# - SAMPLING_STEPS (optional: default = 128) +# - SEED (optional: default = 1) + +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${SAMPLING_STEPS}" ]; then + SAMPLING_STEPS=128 +fi +if [ -z "${SEED}" ]; then + SEED=1 +fi + +if [ "${MODEL}" = "ar" ]; then + parameterization="ar" + diffusion="absorbing_state" + TRAIN_T=0 + time_conditioning=False + sampling_use_cache=False + CKPT="${PWD}/outputs/amazon_polarity/ar" +elif [ "${MODEL}" = "mdlm" ]; then + parameterization="subs" + diffusion="absorbing_state" + TRAIN_T=0 + time_conditioning=False + sampling_use_cache=True + CKPT="${PWD}/outputs/amazon_polarity/mdlm" +elif [ "${MODEL}" = "udlm" ]; then + parameterization="d3pm" + diffusion="uniform" + TRAIN_T=0 + time_conditioning=True + sampling_use_cache=False + CKPT="${PWD}/outputs/amazon_polarity/udlm" +else + echo "Invalid MODEL: ${MODEL}" + exit 1 +fi +generated_seqs_path="${CKPT}/samples-amazon_polarity-gen-ppl-eval-_T-${SAMPLING_STEPS}_seed-${SEED}.json" + +# shellcheck disable=SC2086 +python -u -m main \ + hydra.output_subdir=null \ + hydra.run.dir="${CKPT}" \ + hydra/job_logging=disabled \ + hydra/hydra_logging=disabled \ + seed=${SEED} \ + mode="gen_ppl_eval" \ + eval.checkpoint_path="${CKPT}/checkpoints/best.ckpt" \ + data=amazon_polarity \ + backbone=dit \ + model=small \ + model.length=128 \ + training.guidance=null \ + parameterization=${parameterization} \ + diffusion=${diffusion} \ + time_conditioning=${time_conditioning} \ + T=${TRAIN_T} \ + sampling.num_sample_batches=32 \ + sampling.batch_size=32 \ + sampling.steps=${SAMPLING_STEPS} \ + sampling.use_cache=${sampling_use_cache} \ + eval.generated_samples_path=${generated_seqs_path} \ + +eval.generative_ppl_model_name_or_path="gpt2-large" diff --git a/scripts/eval_amazon_polarity_guidance.sh b/scripts/eval_amazon_polarity_guidance.sh new file mode 100644 index 0000000000000000000000000000000000000000..74c16a3a4059134f921ca80e6f0d07ab666321d4 --- /dev/null +++ b/scripts/eval_amazon_polarity_guidance.sh @@ -0,0 +1,217 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the users login environment +#SBATCH --mem=32000 # server memory requested (per node) +#SBATCH -t 24:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +MODEL=<ar|mdlm|udlm> +GUIDANCE=<cfg|fudge|cbg|pplm|nos> +... additional args for each guidance method ... +sbatch \ + --export=ALL,MODEL=${MODEL},GUIDANCE=${GUIDANCE},...
\ + --job-name=eval_amazon_polarity_${GUIDANCE}_${MODEL} \ + eval_amazon_polarity_guidance.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh || exit +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (choices: ar, mdlm, udlm) +# - GUIDANCE (each method has its own required args) +# - CONDITION (optional: default = 1) +# - SAMPLING_STEPS (optional: default = 128) +# - SEED (optional: default = 1) + +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${GUIDANCE}" ]; then + echo "GUIDANCE is not set" + exit 1 +fi +if [ -z "${CONDITION}" ]; then + CONDITION=1 +fi +if [ -z "${SAMPLING_STEPS}" ]; then + SAMPLING_STEPS=128 +fi +if [ -z "${SEED}" ]; then + SEED=1 +fi + +# CKPT below is unconditional model (will be overridden if GUIDANCE = "cfg") +if [ "${MODEL}" = "ar" ]; then + parameterization="ar" + diffusion="absorbing_state" + TRAIN_T=0 + time_conditioning=False + sampling_use_cache=False + CKPT="${PWD}/outputs/amazon_polarity/ar" +elif [ "${MODEL}" = "mdlm" ]; then + parameterization="subs" + diffusion="absorbing_state" + TRAIN_T=0 + time_conditioning=False + sampling_use_cache=True + CKPT="${PWD}/outputs/amazon_polarity/mdlm" +elif [ "${MODEL}" = "udlm" ]; then + parameterization="d3pm" + diffusion="uniform" + TRAIN_T=0 + time_conditioning=True + sampling_use_cache=False + CKPT="${PWD}/outputs/amazon_polarity/udlm" +else + echo "Invalid MODEL: ${MODEL}" + exit 1 +fi + + +guidance_args="guidance=${GUIDANCE} guidance.condition=${CONDITION}" +###### CFG ###### +if [ "${GUIDANCE}" == "cfg" ]; then + # Expecting: + # - GAMMA + if [ -z "${GAMMA}" ]; then + echo "GAMMA is not set" + exit 1 + fi + if [ "${MODEL}" = "ar" ]; then + CKPT="${PWD}/outputs/amazon_polarity/ar" + elif [ "${MODEL}" = "mdlm" ]; then + CKPT="${PWD}/outputs/amazon_polarity/mdlm" + elif [ "${MODEL}" = "udlm" ]; then + CKPT="${PWD}/outputs/amazon_polarity/udlm" + fi + guidance_args="${guidance_args} guidance.gamma=${GAMMA}" + results_csv_path="${CKPT}/amazon_polarity-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples-amazon_polarity-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.json" +###### FUDGE / CBG ###### +elif [ "${GUIDANCE}" = "fudge" ] || [ "${GUIDANCE}" = "cbg" ]; then + # Expecting: + # - GAMMA + # - USE_APPROX (for cbg) + if [ -z "${GAMMA}" ]; then + echo "GAMMA is not set" + exit 1 + fi + if [ "${MODEL}" = "ar" ]; then + CLASS_CKPT="${PWD}/outputs/amazon_polarity/fudge_classifier" + elif [ "${MODEL}" = "mdlm" ]; then + CLASS_CKPT="${PWD}/outputs/amazon_polarity/classifier/absorbing_state_T-0" + elif [ "${MODEL}" = "udlm" ]; then + CLASS_CKPT="${PWD}/outputs/amazon_polarity/classifier/uniform_T-0" + fi + guidance_args="${guidance_args} classifier_model=tiny-classifier classifier_backbone=dit guidance.classifier_checkpoint_path=${CLASS_CKPT}/checkpoints/best.ckpt guidance.gamma=${GAMMA}" + if [ "${GUIDANCE}" = "fudge" ] || [ "${GUIDANCE}" = "cbg_topk" ]; then + guidance_args="${guidance_args} guidance.topk=200 classifier_model.pooling=no_pooling" # Use full vocab size for topk + fi + if [ "${GUIDANCE}" = "cbg" ]; then + if [ -z "${USE_APPROX}" ]; then + echo "USE_APPROX is not set" + exit 1 + fi + guidance_args="${guidance_args} guidance.use_approx=${USE_APPROX}" + results_csv_path="${CKPT}/amazon_polarity-eval-${GUIDANCE}_approx-${USE_APPROX}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.csv" + 
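# For CBG, guidance.gamma acts as a temperature on the classifier term when the denoiser's per-token distribution is re-weighted. A schematic sketch of that re-weighting (shapes, names, and the exact mixing rule here are illustrative, not the repo's implementation):

```python
import torch

B, L, V = 2, 16, 32  # batch, length, vocab size (illustrative)
log_p_x = torch.log_softmax(torch.randn(B, L, V), dim=-1)  # denoiser log-probs
log_p_y = torch.log_softmax(torch.randn(B, L, V), dim=-1)  # classifier log p(y|x)
gamma = 3.0
guided = torch.log_softmax(log_p_x + gamma * log_p_y, dim=-1)
# gamma = 0 recovers unguided sampling; larger gamma sharpens toward class y.
```

+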
generated_seqs_path="${CKPT}/samples-amazon_polarity-eval-${GUIDANCE}_approx-${USE_APPROX}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.json" + else + results_csv_path="${CKPT}/amazon_polarity-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples-amazon_polarity-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.json" + fi +###### PPLM / NOS ###### +elif [ "${GUIDANCE}" = "pplm" ] || [ "${GUIDANCE}" = "nos" ]; then + if [ "${GUIDANCE}" = "pplm" ]; then + # Expecting: + # - NUM_PPLM_STEPS + # - PPLM_STEP_SIZE + # - PPLM_STABILITY_COEF + if [ -z "${NUM_PPLM_STEPS}" ]; then + echo "NUM_PPLM_STEPS is not set" + exit 1 + fi + if [ -z "${PPLM_STEP_SIZE}" ]; then + echo "PPLM_STEP_SIZE is not set" + exit 1 + fi + if [ -z "${PPLM_STABILITY_COEF}" ]; then + echo "PPLM_STABILITY_COEF is not set" + exit 1 + fi + guidance_args="${guidance_args} guidance.num_pplm_steps=${NUM_PPLM_STEPS} guidance.pplm_step_size=${PPLM_STEP_SIZE} guidance.pplm_stability_coef=${PPLM_STABILITY_COEF}" + results_csv_path="${CKPT}/amazon_polarity-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_NUM_PPLM_STEPS-${NUM_PPLM_STEPS}_PPLM_STEP_SIZE-${PPLM_STEP_SIZE}_PPLM_STABILITY_COEF-${PPLM_STABILITY_COEF}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples_amazon_polarity-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_NUM_PPLM_STEPS-${NUM_PPLM_STEPS}_PPLM_STEP_SIZE-${PPLM_STEP_SIZE}_PPLM_STABILITY_COEF-${PPLM_STABILITY_COEF}_seed-${SEED}.json" + else + # Expecting: + # - NUM_NOS_STEPS + # - NOS_STEP_SIZE + # - NOS_STABILITY_COEF + if [ -z "${NUM_NOS_STEPS}" ]; then + echo "NUM_NOS_STEPS is not set" + exit 1 + fi + if [ -z "${NOS_STEP_SIZE}" ]; then + echo "NOS_STEP_SIZE is not set" + exit 1 + fi + if [ -z "${NOS_STABILITY_COEF}" ]; then + echo "NOS_STABILITY_COEF is not set" + exit 1 + fi + guidance_args="${guidance_args} guidance.num_nos_steps=${NUM_NOS_STEPS} guidance.nos_step_size=${NOS_STEP_SIZE} guidance.nos_stability_coef=${NOS_STABILITY_COEF}" + results_csv_path="${CKPT}/amazon_polarity-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_NUM_NOS_STEPS-${NUM_NOS_STEPS}_NOS_STEP_SIZE-${NOS_STEP_SIZE}_NOS_STABILITY_COEF-${NOS_STABILITY_COEF}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples_amazon_polarity-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_NUM_NOS_STEPS-${NUM_NOS_STEPS}_NOS_STEP_SIZE-${NOS_STEP_SIZE}_NOS_STABILITY_COEF-${NOS_STABILITY_COEF}_seed-${SEED}.json" + fi + + if [ "${MODEL}" = "ar" ]; then + CLASS_CKPT="${PWD}/outputs/amazon_polarity/pplm_classifier/ar_lr-2e-3" + elif [ "${MODEL}" = "mdlm" ]; then + CLASS_CKPT="${PWD}/outputs/amazon_polarity/pplm_classifier/mdlm_lr-2e-3" + elif [ "${MODEL}" = "udlm" ]; then + CLASS_CKPT="${PWD}/outputs/amazon_polarity/pplm_classifier/udlm_lr-2e-3" + fi + guidance_args="${guidance_args} classifier_model=small-classifier classifier_backbone=dit guidance.classifier_checkpoint_path=${CLASS_CKPT}/checkpoints/best.ckpt" +else + echo "Invalid GUIDANCE: ${GUIDANCE}" + exit 1 +fi + +# shellcheck disable=SC2086 +python -u guidance_eval/amazon_polarity_eval.py \ + hydra.output_subdir=null \ + hydra.run.dir="${CKPT}" \ + hydra/job_logging=disabled \ + hydra/hydra_logging=disabled \ + seed=${SEED} \ + mode=amazon_polarity_eval \ + eval.checkpoint_path="${CKPT}/checkpoints/best.ckpt" \ + data=amazon_polarity \ + backbone=dit \ + model=small \ + model.length=128 \ + training.guidance=null \ + parameterization=${parameterization} \ + diffusion=${diffusion} \ + time_conditioning=${time_conditioning} \ + T=${TRAIN_T} \ + sampling.num_sample_batches=32 
\ + sampling.batch_size=32 \ + sampling.steps=${SAMPLING_STEPS} \ + sampling.use_cache=${sampling_use_cache} \ + +eval.results_csv_path=${results_csv_path} \ + eval.generated_samples_path=${generated_seqs_path} \ + +eval.classifier_model_name_or_path="AdamCodd/distilbert-base-uncased-finetuned-sentiment-amazon" \ + +eval.generative_ppl_model_name_or_path="gpt2-large" \ + ${guidance_args} diff --git a/scripts/eval_lm1b_gen_ppl.sh b/scripts/eval_lm1b_gen_ppl.sh new file mode 100644 index 0000000000000000000000000000000000000000..35f03d48011f34ca1960d2b766e117e4d3b028a0 --- /dev/null +++ b/scripts/eval_lm1b_gen_ppl.sh @@ -0,0 +1,99 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the users login environment +#SBATCH --mem=32000 # server memory requested (per node) +#SBATCH -t 96:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +MODEL=<ar|mdlm|udlm> +sbatch \ + --export=ALL,MODEL=${MODEL} \ + --job-name=eval_lm1b_gen_ppl_${MODEL} \ + eval_lm1b_gen_ppl.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh || exit +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (choices: ar, mdlm, udlm) +# - SAMPLING_STEPS (optional: default = 128) +# - SEED (optional: default = 1) +# - USE_FLOAT64 (optional: default = False) + +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${SAMPLING_STEPS}" ]; then + SAMPLING_STEPS=128 +fi +if [ -z "${SEED}" ]; then + SEED=1 +fi +if [ -z "${USE_FLOAT64}" ]; then + USE_FLOAT64=False +fi + +if [ "${MODEL}" = "ar" ]; then + parameterization="ar" + diffusion="absorbing_state" + TRAIN_T=0 + time_conditioning=False + sampling_use_cache=False + CKPT="${PWD}/outputs/lm1b/ar" +elif [ "${MODEL}" = "mdlm" ]; then + parameterization="subs" + diffusion="absorbing_state" + TRAIN_T=0 + time_conditioning=False + sampling_use_cache=True + CKPT="${PWD}/outputs/lm1b/mdlm" +elif [ "${MODEL}" = "udlm" ]; then + parameterization="d3pm" + diffusion="uniform" + TRAIN_T=0 + time_conditioning=True + sampling_use_cache=False + CKPT="${PWD}/outputs/lm1b/udlm" +else + echo "Invalid MODEL: ${MODEL}" + exit 1 +fi +generated_seqs_path="${CKPT}/samples-lm1b-gen-ppl-eval-float64-${USE_FLOAT64}_add-CLS_T-${SAMPLING_STEPS}_seed-${SEED}.json" + +# shellcheck disable=SC2086 +python -u -m main \ + hydra.output_subdir=null \ + hydra.run.dir="${PWD}" \ + hydra/job_logging=disabled \ + hydra/hydra_logging=disabled \ + seed=${SEED} \ + mode="gen_ppl_eval" \ + eval.checkpoint_path="${CKPT}/checkpoints/last.ckpt" \ + data=lm1b \ + backbone=dit \ + model=small \ + model.length=128 \ + training.guidance=null \ + parameterization=${parameterization} \ + diffusion=${diffusion} \ + time_conditioning=${time_conditioning} \ + T=${TRAIN_T} \ + sampling.num_sample_batches=32 \ + sampling.batch_size=32 \ + sampling.steps=${SAMPLING_STEPS} \ + sampling.use_cache=${sampling_use_cache} \ + sampling.use_float64=${USE_FLOAT64} \ + eval.generated_samples_path=${generated_seqs_path} \ + +eval.generative_ppl_model_name_or_path="gpt2-large" diff --git a/scripts/eval_lm1b_ppl.sh b/scripts/eval_lm1b_ppl.sh new file mode 100644 index 0000000000000000000000000000000000000000..62a21d5992fef23554d5ea3b738a320f42e5bc18 --- /dev/null +++
b/scripts/eval_lm1b_ppl.sh @@ -0,0 +1,87 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the users login environment +#SBATCH --mem=32000 # server memory requested (per node) +#SBATCH -t 96:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +MODEL=<ar|mdlm|udlm> +sbatch \ + --export=ALL,MODEL=${MODEL} \ + --job-name=eval_lm1b_ppl_${MODEL} \ + eval_lm1b_ppl.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh || exit +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (choices: ar, mdlm, udlm) +# - SEED (optional: default = 1) + +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${SEED}" ]; then + SEED=1 +fi + +if [ "${MODEL}" = "ar" ]; then + PARAMETERIZATION="ar" + DIFFUSION="absorbing_state" + TRAIN_T=0 + ZERO_RECON_LOSS=False + TIME_COND=False + BATCH_SIZE=128 + CKPT="${PWD}/outputs/lm1b/ar" +elif [ "${MODEL}" = "mdlm" ]; then + PARAMETERIZATION="subs" + DIFFUSION="absorbing_state" + TRAIN_T=0 + ZERO_RECON_LOSS=False + TIME_COND=False + BATCH_SIZE=128 + CKPT="${PWD}/outputs/lm1b/mdlm" +elif [ "${MODEL}" = "udlm" ]; then + PARAMETERIZATION="d3pm" + DIFFUSION="uniform" + TRAIN_T=0 + ZERO_RECON_LOSS=True + TIME_COND=True + BATCH_SIZE=64 + CKPT="${PWD}/outputs/lm1b/udlm" +else + echo "Invalid MODEL: ${MODEL}" + exit 1 +fi + +# shellcheck disable=SC2086 +python -u -m main \ + hydra.output_subdir=null \ + hydra.run.dir="${PWD}" \ + hydra/job_logging=disabled \ + hydra/hydra_logging=disabled \ + seed=${SEED} \ + mode="ppl_eval" \ + eval.checkpoint_path="${CKPT}/checkpoints/last.ckpt" \ + eval.generate_samples=False \ + loader.eval_batch_size=${BATCH_SIZE} \ + data=lm1b \ + data.wrap=False \ + backbone=dit \ + model=small \ + model.length=128 \ + training.guidance=null \ + parameterization=${PARAMETERIZATION} \ + diffusion=${DIFFUSION} \ + time_conditioning=${TIME_COND} \ + zero_recon_loss=${ZERO_RECON_LOSS} \ + T=${TRAIN_T} diff --git a/scripts/eval_qm9_guidance.sh b/scripts/eval_qm9_guidance.sh new file mode 100644 index 0000000000000000000000000000000000000000..4f19dc67f3842b26178e27c4a624ef197aa3e137 --- /dev/null +++ b/scripts/eval_qm9_guidance.sh @@ -0,0 +1,265 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the users login environment +#SBATCH --mem=32000 # server memory requested (per node) +#SBATCH -t 24:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +MODEL=<ar|mdlm|udlm> +PROP=<qed|ring_count> +GUIDANCE=<cfg|fudge|cbg|pplm|nos> +... additional args for each guidance method ... +sbatch \ + --export=ALL,MODEL=${MODEL},PROP=${PROP},GUIDANCE=${GUIDANCE},...
\ + --job-name=eval_qm9_${GUIDANCE}_${PROP}_${MODEL} \ + eval_qm9_guidance.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh || exit +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (choices: ar, mdlm, udlm) +# - PROP (choices: qed, ring_count) +# - GUIDANCE (each method has its own required args) +# - CONDITION (optional: default = 1) +# - SAMPLING_STEPS (optional: default = 32) +# - SEED (optional: default = 1) + +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${PROP}" ]; then + echo "PROP is not set" + exit 1 +fi +if [ -z "${GUIDANCE}" ]; then + echo "GUIDANCE is not set" + exit 1 +fi +if [ -z "${CONDITION}" ]; then + CONDITION=1 +fi +if [ -z "${SAMPLING_STEPS}" ]; then + SAMPLING_STEPS=32 +fi +if [ -z "${SEED}" ]; then + SEED=1 +fi + + +# CKPT below is unconditional model (will be overridden if GUIDANCE = "cfg") +if [ "${MODEL}" = "ar" ]; then + parameterization="ar" + diffusion="absorbing_state" + TRAIN_T=0 + time_conditioning=False + sampling_use_cache=False + SAMPLING_STEPS=32 + CKPT="${PWD}/outputs/qm9/ar_no-guidance" +elif [ "${MODEL}" = "mdlm" ]; then + parameterization="subs" + diffusion="absorbing_state" + TRAIN_T=0 + time_conditioning=False + sampling_use_cache=True + CKPT="${PWD}/outputs/qm9/mdlm_no-guidance" +elif [ "${MODEL}" = "udlm" ]; then + parameterization="d3pm" + diffusion="uniform" + TRAIN_T=0 + time_conditioning=True + sampling_use_cache=False + CKPT="${PWD}/outputs/qm9/udlm_no-guidance" +else + echo "Invalid MODEL: ${MODEL}" + exit 1 +fi + + +guidance_args="guidance=${GUIDANCE} guidance.condition=${CONDITION}" +###### CFG ###### +if [ "${GUIDANCE}" == "cfg" ]; then + # Expecting: + # - GAMMA + if [ -z "${GAMMA}" ]; then + echo "GAMMA is not set" + exit 1 + fi + if [ "${PROP}" = "qed" ]; then + if [ "${MODEL}" = "ar" ]; then + CKPT="${PWD}/outputs/qm9/ar_qed" + elif [ "${MODEL}" = "mdlm" ]; then + CKPT="${PWD}/outputs/qm9/mdlm_qed" + elif [ "${MODEL}" = "udlm" ]; then + CKPT="${PWD}/outputs/qm9/udlm_qed" + fi + elif [ "${PROP}" = "ring_count" ]; then + if [ "${MODEL}" = "ar" ]; then + CKPT="${PWD}/outputs/qm9/ar_ring_count" + elif [ "${MODEL}" = "mdlm" ]; then + CKPT="${PWD}/outputs/qm9/mdlm_ring_count" + elif [ "${MODEL}" = "udlm" ]; then + CKPT="${PWD}/outputs/qm9/udlm_ring_count" + fi + else + echo "Invalid PROP: ${PROP}" + exit 1 + fi + guidance_args="${guidance_args} guidance.gamma=${GAMMA}" + results_csv_path="${CKPT}/qm9-eval-${GUIDANCE}_${PROP}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples-qm9-eval-${GUIDANCE}_${PROP}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.json" +###### FUDGE / CBG ###### +elif [ "${GUIDANCE}" = "fudge" ] || [ "${GUIDANCE}" = "cbg" ]; then + # Expecting: + # - GAMMA + # - USE_APPROX (for cbg) + if [ -z "${GAMMA}" ]; then + echo "GAMMA is not set" + exit 1 + fi + if [ "${PROP}" = "qed" ]; then + if [ "${MODEL}" = "ar" ]; then + CLASS_CKPT="${PWD}/outputs/qm9/fudge_classifier/qed" + elif [ "${MODEL}" = "mdlm" ]; then + CLASS_CKPT="${PWD}/outputs/qm9/classifier/qed_absorbing_state_T-0" + elif [ "${MODEL}" = "udlm" ]; then + CLASS_CKPT="${PWD}/outputs/qm9/classifier/qed_uniform_T-0" + fi + elif [ "${PROP}" = "ring_count" ]; then + if [ "${MODEL}" = "ar" ]; then + CLASS_CKPT="${PWD}/outputs/qm9/fudge_classifier/ring_count" + elif [ "${MODEL}" = "mdlm" ]; then + CLASS_CKPT="${PWD}/outputs/qm9/classifier/ring_count_absorbing_state_T-0" + elif [ "${MODEL}" = "udlm" ]; then + 
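# Guidance classifiers are trained under the same corruption process as the diffusion model, hence the separate absorbing-state vs. uniform checkpoints. +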
CLASS_CKPT="${PWD}/outputs/qm9/classifier/ring_count_uniform_T-0" + fi + else + echo "Invalid PROP: ${PROP}" + exit 1 + fi + guidance_args="${guidance_args} classifier_model=tiny-classifier classifier_backbone=dit guidance.classifier_checkpoint_path=${CLASS_CKPT}/checkpoints/best.ckpt guidance.gamma=${GAMMA}" + if [ "${GUIDANCE}" = "fudge" ]; then + guidance_args="${guidance_args} guidance.topk=40 classifier_model.pooling=no_pooling" # Use full vocab size for topk + fi + if [ "${GUIDANCE}" = "cbg" ]; then + if [ -z "${USE_APPROX}" ]; then + echo "USE_APPROX is not set" + exit 1 + fi + guidance_args="${guidance_args} guidance.use_approx=${USE_APPROX}" + results_csv_path="${CKPT}/qm9-eval-${GUIDANCE}_approx-${USE_APPROX}_${PROP}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples-qm9-eval-${GUIDANCE}_approx-${USE_APPROX}_${PROP}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.json" + else + results_csv_path="${CKPT}/qm9-eval-${GUIDANCE}_${PROP}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples-qm9-eval-${GUIDANCE}_${PROP}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.json" + fi +###### PPLM / NOS ###### +elif [ "${GUIDANCE}" = "pplm" ] || [ "${GUIDANCE}" = "nos" ]; then + if [ "${GUIDANCE}" = "pplm" ]; then + # Expecting: + # - NUM_PPLM_STEPS + # - PPLM_STEP_SIZE + # - PPLM_STABILITY_COEF + if [ -z "${NUM_PPLM_STEPS}" ]; then + echo "NUM_PPLM_STEPS is not set" + exit 1 + fi + if [ -z "${PPLM_STEP_SIZE}" ]; then + echo "PPLM_STEP_SIZE is not set" + exit 1 + fi + if [ -z "${PPLM_STABILITY_COEF}" ]; then + echo "PPLM_STABILITY_COEF is not set" + exit 1 + fi + guidance_args="${guidance_args} guidance.num_pplm_steps=${NUM_PPLM_STEPS} guidance.pplm_step_size=${PPLM_STEP_SIZE} guidance.pplm_stability_coef=${PPLM_STABILITY_COEF}" + results_csv_path="${CKPT}/qm9-eval-${GUIDANCE}_${PROP}_T-${SAMPLING_STEPS}_NUM_PPLM_STEPS-${NUM_PPLM_STEPS}_PPLM_STEP_SIZE-${PPLM_STEP_SIZE}_PPLM_STABILITY_COEF-${PPLM_STABILITY_COEF}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples_qm9-eval-${GUIDANCE}_${PROP}_T-${SAMPLING_STEPS}_NUM_PPLM_STEPS-${NUM_PPLM_STEPS}_PPLM_STEP_SIZE-${PPLM_STEP_SIZE}_PPLM_STABILITY_COEF-${PPLM_STABILITY_COEF}_seed-${SEED}.json" + else + # Expecting: + # - NUM_NOS_STEPS + # - NOS_STEP_SIZE + # - NOS_STABILITY_COEF + if [ -z "${NUM_NOS_STEPS}" ]; then + echo "NUM_NOS_STEPS is not set" + exit 1 + fi + if [ -z "${NOS_STEP_SIZE}" ]; then + echo "NOS_STEP_SIZE is not set" + exit 1 + fi + if [ -z "${NOS_STABILITY_COEF}" ]; then + echo "NOS_STABILITY_COEF is not set" + exit 1 + fi + guidance_args="${guidance_args} guidance.num_nos_steps=${NUM_NOS_STEPS} guidance.nos_step_size=${NOS_STEP_SIZE} guidance.nos_stability_coef=${NOS_STABILITY_COEF}" + results_csv_path="${CKPT}/qm9-eval-${GUIDANCE}_${PROP}_T-${SAMPLING_STEPS}_NUM_NOS_STEPS-${NUM_NOS_STEPS}_NOS_STEP_SIZE-${NOS_STEP_SIZE}_NOS_STABILITY_COEF-${NOS_STABILITY_COEF}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples_qm9-eval-${GUIDANCE}_${PROP}_T-${SAMPLING_STEPS}_NUM_NOS_STEPS-${NUM_NOS_STEPS}_NOS_STEP_SIZE-${NOS_STEP_SIZE}_NOS_STABILITY_COEF-${NOS_STABILITY_COEF}_seed-${SEED}.json" + fi + + if [ "${PROP}" = "qed" ]; then + if [ "${MODEL}" = "ar" ]; then + CLASS_CKPT="${PWD}/outputs/qm9/pplm_classifier/qed_ar" + elif [ "${MODEL}" = "mdlm" ]; then + CLASS_CKPT="${PWD}/outputs/qm9/pplm_classifier/qed_mdlm" + elif [ "${MODEL}" = "udlm" ]; then + CLASS_CKPT="${PWD}/outputs/qm9/pplm_classifier/qed_udlm" + fi + elif [ "${PROP}" = "ring_count" ]; then + if [ 
"${MODEL}" = "ar" ]; then + CLASS_CKPT="${PWD}/outputs/qm9/pplm_classifier/ring_count_ar" + elif [ "${MODEL}" = "mdlm" ]; then + CLASS_CKPT="${PWD}/outputs/qm9/pplm_classifier/ring_count_mdlm" + elif [ "${MODEL}" = "udlm" ]; then + CLASS_CKPT="${PWD}/outputs/qm9/pplm_classifier/ring_count_udlm" + fi + else + echo "Invalid PROP: ${PROP}" + exit 1 + fi + guidance_args="${guidance_args} classifier_model=small-classifier classifier_backbone=dit guidance.classifier_checkpoint_path=${CLASS_CKPT}/checkpoints/best.ckpt" +else + echo "Invalid GUIDANCE: ${GUIDANCE}" + exit 1 +fi + +# shellcheck disable=SC2086 +python -u guidance_eval/qm9_eval.py \ + hydra.output_subdir=null \ + hydra.run.dir="${CKPT}" \ + hydra/job_logging=disabled \ + hydra/hydra_logging=disabled \ + seed=${SEED} \ + mode=qm9_eval \ + eval.checkpoint_path="${CKPT}/checkpoints/best.ckpt" \ + data=qm9 \ + data.label_col="${PROP}" \ + data.label_col_pctile=90 \ + data.num_classes=2 \ + model=small \ + backbone=dit \ + model.length=32 \ + training.guidance=null \ + parameterization=${parameterization} \ + diffusion=${diffusion} \ + time_conditioning=${time_conditioning} \ + T=${TRAIN_T} \ + sampling.num_sample_batches=64 \ + sampling.batch_size=16 \ + sampling.steps=${SAMPLING_STEPS} \ + sampling.use_cache=${sampling_use_cache} \ + +eval.results_csv_path=${results_csv_path} \ + eval.generated_samples_path=${generated_seqs_path} \ + ${guidance_args} diff --git a/scripts/eval_ten_species_guidance.sh b/scripts/eval_ten_species_guidance.sh new file mode 100644 index 0000000000000000000000000000000000000000..c5a0efa766dc351d961d713202c6511f354c5530 --- /dev/null +++ b/scripts/eval_ten_species_guidance.sh @@ -0,0 +1,229 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the users login environment +#SBATCH --mem=32000 # server memory requested (per node) +#SBATCH -t 24:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +< +GUIDANCE= +... additional args for each guidance method ... +sbatch \ + --export=ALL,MODEL=${MODEL},GUIDANCE=${GUIDANCE},... 
\ + --job-name=eval_ten_species_${GUIDANCE}_${MODEL} \ + eval_ten_species_guidance.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh || exit +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (choices: ar, mdlm, udlm) +# - GUIDANCE (each method has its own required args) +# - SAMPLING_STEPS (optional: default = 128) +# - SEED (optional: default = 1) + +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${GUIDANCE}" ]; then + echo "GUIDANCE is not set" + exit 1 +fi +if [ -z "${SAMPLING_STEPS}" ]; then + SAMPLING_STEPS=128 +fi +if [ -z "${SEED}" ]; then + SEED=1 +fi + + +# CKPT below is unconditional model (will be overridden if GUIDANCE = "cfg") +if [ "${MODEL}" = "ar" ]; then + parameterization="ar" + diffusion="absorbing_state" + TRAIN_T=0 + BIDIRECTIONAL=False + BIDIRECTIONAL_STRATEGY=null + BIDIRECTIONAL_WEIGHT_TIE=null + time_conditioning=False + sampling_use_cache=False + SAMPLING_STEPS=32768 + CKPT="${PWD}/outputs/ten_species/ar_no-guidance" +elif [ "${MODEL}" = "mdlm" ]; then + parameterization="subs" + diffusion="absorbing_state" + TRAIN_T=0 + BIDIRECTIONAL=True + BIDIRECTIONAL_STRATEGY="add" + BIDIRECTIONAL_WEIGHT_TIE=True + time_conditioning=False + sampling_use_cache=True + CKPT="${PWD}/outputs/ten_species/mdlm_no-guidance" +elif [ "${MODEL}" = "udlm" ]; then + parameterization="d3pm" + diffusion="uniform" + TRAIN_T=0 + BIDIRECTIONAL=True + BIDIRECTIONAL_STRATEGY="add" + BIDIRECTIONAL_WEIGHT_TIE=True + time_conditioning=True + sampling_use_cache=False + CKPT="${PWD}/outputs/ten_species/udlm_no-guidance" +else + echo "Invalid MODEL: ${MODEL}" + exit 1 +fi + + +guidance_args="guidance=${GUIDANCE}" +###### CFG ###### +if [ "${GUIDANCE}" == "cfg" ]; then + # Expecting: + # - GAMMA + if [ -z "${GAMMA}" ]; then + echo "GAMMA is not set" + exit 1 + fi + if [ "${MODEL}" = "ar" ]; then + CKPT="${PWD}/outputs/ten_species/ar" + elif [ "${MODEL}" = "mdlm" ]; then + CKPT="${PWD}/outputs/ten_species/mdlm" + elif [ "${MODEL}" = "udlm" ]; then + CKPT="${PWD}/outputs/ten_species/udlm" + fi + guidance_args="${guidance_args} guidance.gamma=${GAMMA}" + results_csv_path="${CKPT}/ten_species-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples-ten_species-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.json" +###### FUDGE / CBG ###### +elif [ "${GUIDANCE}" = "fudge" ] || [ "${GUIDANCE}" = "cbg" ]; then + # Expecting: + # - GAMMA + # - USE_APPROX (for cbg) + if [ -z "${GAMMA}" ]; then + echo "GAMMA is not set" + exit 1 + fi + if [ "${MODEL}" = "ar" ]; then + CLASS_CKPT="${PWD}/outputs/ten_species/fudge_classifier" + elif [ "${MODEL}" = "mdlm" ]; then + CLASS_CKPT="${PWD}/outputs/ten_species/classifier/absorbing_state_T-0" + elif [ "${MODEL}" = "udlm" ]; then + CLASS_CKPT="${PWD}/outputs/ten_species/classifier/uniform_T-0" + fi + guidance_args="${guidance_args} classifier_model=tiny-dimamba-classifier classifier_backbone=dimamba classifier_model.bidirectional=${BIDIRECTIONAL} classifier_model.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} classifier_model.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} guidance.classifier_checkpoint_path=${CLASS_CKPT}/checkpoints/best.ckpt guidance.gamma=${GAMMA}" + if [ "${GUIDANCE}" = "fudge" ]; then + guidance_args="${guidance_args} guidance.topk=12 classifier_model.pooling=no_pooling" # Use full vocab size for topk + fi + if [ "${GUIDANCE}" = "cbg" ]; then + if [ -z "${USE_APPROX}" ]; then + echo
"USE_APPROX is not set" + exit 1 + fi + guidance_args="${guidance_args} guidance.use_approx=${USE_APPROX}" + results_csv_path="${CKPT}/ten_species-eval-${GUIDANCE}_approx-${USE_APPROX}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples-ten_species-eval-${GUIDANCE}_approx-${USE_APPROX}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.json" + else + results_csv_path="${CKPT}/ten_species-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples-ten_species-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_gamma-${GAMMA}_seed-${SEED}.json" + fi +###### PPLM / NOS ###### +elif [ "${GUIDANCE}" = "pplm" ] || [ "${GUIDANCE}" = "nos" ]; then + if [ "${GUIDANCE}" = "pplm" ]; then + # Expecting: + # - NUM_PPLM_STEPS + # - PPLM_STEP_SIZE + # - PPLM_STABILITY_COEF + if [ -z "${NUM_PPLM_STEPS}" ]; then + echo "NUM_PPLM_STEPS is not set" + exit 1 + fi + if [ -z "${PPLM_STEP_SIZE}" ]; then + echo "PPLM_STEP_SIZE is not set" + exit 1 + fi + if [ -z "${PPLM_STABILITY_COEF}" ]; then + echo "PPLM_STABILITY_COEF is not set" + exit 1 + fi + guidance_args="${guidance_args} guidance.num_pplm_steps=${NUM_PPLM_STEPS} guidance.pplm_step_size=${PPLM_STEP_SIZE} guidance.pplm_stability_coef=${PPLM_STABILITY_COEF}" + results_csv_path="${CKPT}/ten_species-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_NUM_PPLM_STEPS-${NUM_PPLM_STEPS}_PPLM_STEP_SIZE-${PPLM_STEP_SIZE}_PPLM_STABILITY_COEF-${PPLM_STABILITY_COEF}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples_ten_species-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_NUM_PPLM_STEPS-${NUM_PPLM_STEPS}_PPLM_STEP_SIZE-${PPLM_STEP_SIZE}_PPLM_STABILITY_COEF-${PPLM_STABILITY_COEF}_seed-${SEED}.json" + else + # Expecting: + # - NUM_NOS_STEPS + # - NOS_STEP_SIZE + # - NOS_STABILITY_COEF + if [ -z "${NUM_NOS_STEPS}" ]; then + echo "NUM_NOS_STEPS is not set" + exit 1 + fi + if [ -z "${NOS_STEP_SIZE}" ]; then + echo "NOS_STEP_SIZE is not set" + exit 1 + fi + if [ -z "${NOS_STABILITY_COEF}" ]; then + echo "NOS_STABILITY_COEF is not set" + exit 1 + fi + guidance_args="${guidance_args} guidance.num_nos_steps=${NUM_NOS_STEPS} guidance.nos_step_size=${NOS_STEP_SIZE} guidance.nos_stability_coef=${NOS_STABILITY_COEF}" + results_csv_path="${CKPT}/ten_species-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_NUM_NOS_STEPS-${NUM_NOS_STEPS}_NOS_STEP_SIZE-${NOS_STEP_SIZE}_NOS_STABILITY_COEF-${NOS_STABILITY_COEF}_seed-${SEED}.csv" + generated_seqs_path="${CKPT}/samples_ten_species-eval-${GUIDANCE}_T-${SAMPLING_STEPS}_NUM_NOS_STEPS-${NUM_NOS_STEPS}_NOS_STEP_SIZE-${NOS_STEP_SIZE}_NOS_STABILITY_COEF-${NOS_STABILITY_COEF}_seed-${SEED}.json" + fi + + if [ "${MODEL}" = "ar" ]; then + CLASS_CKPT="${PWD}/outputs/ten_species/pplm_classifier/ar" + elif [ "${MODEL}" = "mdlm" ]; then + CLASS_CKPT="${PWD}/outputs/ten_species/pplm_classifier/mdlm" + elif [ "${MODEL}" = "udlm" ]; then + CLASS_CKPT="${PWD}/outputs/ten_species/pplm_classifier/udlm" + fi + guidance_args="${guidance_args} classifier_model=dimamba-classifier classifier_backbone=dimamba classifier_model.bidirectional=${BIDIRECTIONAL} classifier_model.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} classifier_model.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} guidance.classifier_checkpoint_path=${CLASS_CKPT}/checkpoints/best.ckpt" +else + echo "Invalid GUIDANCE: ${GUIDANCE}" + exit 1 +fi + +# shellcheck disable=SC2086 +python -u guidance_eval/ten_species_eval.py \ + hydra.output_subdir=null \ + hydra.run.dir="${CKPT}" \ + hydra/job_logging=disabled \ + hydra/hydra_logging=disabled \ + 
seed=${SEED} \ + mode=ten_species_eval \ + eval.checkpoint_path="${CKPT}/checkpoints/best.ckpt" \ + data=ten_species \ + backbone=dimamba \ + model=dimamba \ + model.bidirectional=${BIDIRECTIONAL} \ + model.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} \ + model.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} \ + model.length=32768 \ + training.guidance=null \ + parameterization=${parameterization} \ + diffusion=${diffusion} \ + time_conditioning=${time_conditioning} \ + T=${TRAIN_T} \ + sampling.num_sample_batches=4 \ + sampling.batch_size=16 \ + sampling.steps=${SAMPLING_STEPS} \ + sampling.use_cache=${sampling_use_cache} \ + +eval.results_csv_path=${results_csv_path} \ + eval.generated_samples_path=${generated_seqs_path} \ + ${guidance_args} \ + +eval.train_weights_path=$(realpath ./guidance_eval/train_set_species_frequency.pt) \ + +eval.val_weights_path=$(realpath ./guidance_eval/validation_set_species_frequency.pt) \ + +eval.eval_classifier_checkpoint_path=$(realpath ./outputs/ten_species/eval_classifier/hyenadna-small-32k_from-scratch_nlayer-8/checkpoints/best.ckpt) \ + +eval.kmer_freqs_path=$(realpath ./guidance_eval/validation_set_kmer_stats.pt) diff --git a/scripts/eval_text8_gen_ppl.sh b/scripts/eval_text8_gen_ppl.sh new file mode 100644 index 0000000000000000000000000000000000000000..2b577d4ff33f6a28ddd2d9f87f1b57f6c71e4e7b --- /dev/null +++ b/scripts/eval_text8_gen_ppl.sh @@ -0,0 +1,94 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=32000 # server memory requested (per node) +#SBATCH -t 96:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +sbatch \ + --export=ALL,MODEL=${MODEL} \ + --job-name=eval_text8_gen_ppl_${MODEL} \ + eval_text8_gen_ppl.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh || exit +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (choices: ar, mdlm, udlm) +# - SAMPLING_STEPS (optional: default = 128) +# - SEED (optional: default = 1) + +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${SAMPLING_STEPS}" ]; then + SAMPLING_STEPS=128 +fi +if [ -z "${SEED}" ]; then + SEED=1 +fi + +if [ "${MODEL}" = "ar" ]; then + parameterization="ar" + diffusion="absorbing_state" + TRAIN_T=0 + time_conditioning=False + sampling_use_cache=False + CKPT="${PWD}/outputs/text8/ar" +elif [ "${MODEL}" = "mdlm" ]; then + parameterization="subs" + diffusion="absorbing_state" + TRAIN_T=0 + time_conditioning=False + sampling_use_cache=True + CKPT="${PWD}/outputs/text8/mdlm" +elif [ "${MODEL}" = "udlm" ]; then + parameterization="d3pm" + diffusion="uniform" + TRAIN_T=0 + time_conditioning=True + sampling_use_cache=False + CKPT="${PWD}/outputs/text8/udlm" +else + echo "Invalid MODEL: ${MODEL}" + exit 1 +fi +generated_seqs_path="${CKPT}/samples-text8-gen-ppl-eval-_T-${SAMPLING_STEPS}_seed-${SEED}.json" + +# shellcheck disable=SC2086 +python -u -m main \ + hydra.output_subdir=null \ + hydra.run.dir="${CKPT}" \ + hydra/job_logging=disabled \ + hydra/hydra_logging=disabled \ + seed=${SEED} \ + mode="gen_ppl_eval" \ + eval.checkpoint_path="${CKPT}/checkpoints/best.ckpt" \ + data=text8 \ + backbone=dit \ + model=small \ 
model.length=256 \ + training.guidance=null \ + parameterization=${parameterization} \ + diffusion=${diffusion} \ + time_conditioning=${time_conditioning} \ + T=${TRAIN_T} \ + sampling.num_sample_batches=32 \ + sampling.batch_size=32 \ + sampling.steps=${SAMPLING_STEPS} \ + sampling.use_cache=${sampling_use_cache} \ + eval.generated_samples_path=${generated_seqs_path} \ + +eval.generative_ppl_model_name_or_path="gpt2-large" diff --git a/scripts/train_amazon_polarity.sh b/scripts/train_amazon_polarity.sh new file mode 100644 index 0000000000000000000000000000000000000000..c631ca0adcde2374be4b7086a85988291bc4f110 --- /dev/null +++ b/scripts/train_amazon_polarity.sh @@ -0,0 +1,107 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=64000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=8 +#SBATCH --gres=gpu:8 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +sbatch \ + --export=ALL,MODEL=${MODEL} \ + --job-name=train_amazon_polarity_${MODEL} \ + train_amazon_polarity.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export NCCL_P2P_LEVEL=NVL +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (ar, mdlm, udlm) +# - USE_SIMPLE_CE_LOSS (True, False; optional, default: False) +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${USE_SIMPLE_CE_LOSS}" ]; then + USE_SIMPLE_CE_LOSS=False +fi +RUN_NAME="${MODEL}" +if [ "${USE_SIMPLE_CE_LOSS}" = "True" ]; then + RUN_NAME="${RUN_NAME}_simple-ce" +fi + +if [ "${MODEL}" = "ar" ]; then + # AR + DIFFUSION="absorbing_state" + PARAMETERIZATION="ar" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + sampling_use_cache=False +elif [ "${MODEL}" = "mdlm" ]; then + # MDLM + DIFFUSION="absorbing_state" + PARAMETERIZATION="subs" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + sampling_use_cache=True +elif [ "${MODEL}" = "udlm" ]; then + # UDLM + DIFFUSION="uniform" + PARAMETERIZATION="d3pm" + T=0 + TIME_COND=True + ZERO_RECON_LOSS=True + sampling_use_cache=False +else + echo "MODEL must be one of ar, mdlm, udlm" + exit 1 +fi + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. 
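+# Example (illustrative values, not part of the original script): a preempted job that is requeued with +# MODEL=udlm sbatch --export=ALL,MODEL=udlm --job-name=train_amazon_polarity_udlm train_amazon_polarity.sh +# lands back in the hydra.run.dir set below and can pick up the checkpoints saved there.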
+srun python -u -m main \ + diffusion="${DIFFUSION}" \ + parameterization="${PARAMETERIZATION}" \ + T=${T} \ + time_conditioning=${TIME_COND} \ + zero_recon_loss=${ZERO_RECON_LOSS} \ + data="amazon_polarity" \ + data.wrap=False \ + data.tokenizer_name_or_path=bert-base-uncased \ + loader.global_batch_size=512 \ + loader.eval_global_batch_size=1024 \ + loader.batch_size=64 \ + loader.eval_batch_size=128 \ + backbone="dit" \ + model=small \ + model.length=128 \ + optim.lr=3e-4 \ + training.guidance=null \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=40_000 \ + training.compute_loss_on_pad_tokens=True \ + trainer.log_every_n_steps=100 \ + trainer.max_steps=-1 \ + +trainer.max_epochs=60 \ + trainer.val_check_interval=1.0 \ + trainer.precision=bf16 \ + eval.generate_samples=True \ + sampling.num_sample_batches=1 \ + sampling.batch_size=2 \ + sampling.use_cache=${sampling_use_cache} \ + sampling.steps=128 \ + training.use_simple_ce_loss=${USE_SIMPLE_CE_LOSS} \ + wandb.name="amazon_polarity_${RUN_NAME}" \ + hydra.run.dir="${PWD}/outputs/amazon_polarity/${RUN_NAME}" diff --git a/scripts/train_amazon_polarity_classifier.sh b/scripts/train_amazon_polarity_classifier.sh new file mode 100644 index 0000000000000000000000000000000000000000..3dd5e9fc9a1ee28aa3d9b81f7151cbd1d2ed4911 --- /dev/null +++ b/scripts/train_amazon_polarity_classifier.sh @@ -0,0 +1,64 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=32000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +sbatch \ + --export=ALL,DIFFUSION=${DIFFUSION} \ + --job-name=train_amazon_classifier_${DIFFUSION} \ + train_amazon_polarity_classifier.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export NCCL_P2P_LEVEL=NVL +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - DIFFUSION (absorbing_state or uniform) +if [ -z "${DIFFUSION}" ]; then + echo "DIFFUSION is not set" + exit 1 +fi +T=0 +RUN_NAME="${DIFFUSION}_T-${T}" + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. +srun python -u -m main \ + mode=train_classifier \ + diffusion=${DIFFUSION} \ + T=${T} \ + data=amazon_polarity \ + data.wrap=False \ + data.tokenizer_name_or_path=bert-base-uncased \ + data.label_col=label \ + data.num_classes=2 \ + loader.global_batch_size=512 \ + loader.eval_global_batch_size=1024 \ + classifier_backbone=dit \ + classifier_model=tiny-classifier \ + model.length=128 \ + optim.lr=3e-4 \ + lr_scheduler=cosine_decay_warmup \ + lr_scheduler.warmup_t=1_000 \ + lr_scheduler.lr_min=3e-6 \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=40_000 \ + callbacks.checkpoint_monitor.monitor=val/cross_entropy \ + trainer.max_steps=400_000 \ + trainer.val_check_interval=1.0 \ + wandb.group=train_classifier \ + wandb.name="amazon_polarity-classifier_${RUN_NAME}" \ + hydra.run.dir="${PWD}/outputs/amazon_polarity/classifier/${RUN_NAME}" diff --git a/scripts/train_amazon_polarity_fudge_classifier.sh b/scripts/train_amazon_polarity_fudge_classifier.sh new file mode 100644 index 
0000000000000000000000000000000000000000..b21a9bf45c74ec050eff768298983622c3c9ef93 --- /dev/null +++ b/scripts/train_amazon_polarity_fudge_classifier.sh @@ -0,0 +1,57 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=64000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +sbatch \ + --export=ALL,MODEL=${MODEL} \ + --job-name=train_cifar10_${MODEL} \ + train_cifar10_unet_guidance.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export NCCL_P2P_LEVEL=NVL +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (mdlm, udlm) +# - DATASET_PATH (path to the CIFAR10 data used below) +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi + +RUN_NAME="${MODEL}" +T=0 +if [ "${MODEL}" = "mdlm" ]; then + PARAMETERIZATION=subs + DIFFUSION="absorbing_state" + ZERO_RECON_LOSS=False + time_conditioning=False + sampling_use_cache=True +elif [ "${MODEL}" = "udlm" ]; then + PARAMETERIZATION=d3pm + DIFFUSION="uniform" + ZERO_RECON_LOSS=True + time_conditioning=True + sampling_use_cache=False +else + echo "MODEL must be one of mdlm, udlm" + exit 1 +fi + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. +srun python -u -m main \ + is_vision=True \ + diffusion=${DIFFUSION} \ + parameterization=${PARAMETERIZATION} \ + T=${T} \ + time_conditioning=${time_conditioning} \ + zero_recon_loss=${ZERO_RECON_LOSS} \ + data=cifar10 \ + data.train=${DATASET_PATH} \ + data.valid=${DATASET_PATH} \ + loader.global_batch_size=512 \ + loader.eval_global_batch_size=64 \ + backbone=unet \ + model=unet \ + optim.lr=2e-4 \ + lr_scheduler=constant_warmup \ + lr_scheduler.num_warmup_steps=5000 \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=10_000 \ + trainer.max_steps=300_000 \ + trainer.val_check_interval=10_000 \ + +trainer.check_val_every_n_epoch=null \ + training.guidance.cond_dropout=0.1 \ + eval.generate_samples=True \ + sampling.num_sample_batches=1 \ + sampling.batch_size=2 \ + sampling.use_cache=${sampling_use_cache} \ + sampling.steps=128 \ + wandb.name="cifar10_${RUN_NAME}" \ + hydra.run.dir="${PWD}/outputs/cifar10/${RUN_NAME}" diff --git a/scripts/train_lm1b.sh b/scripts/train_lm1b.sh new file mode 100644 index 0000000000000000000000000000000000000000..9bfbd205ed4a8c613db791108941bd1543bc7bc6 --- /dev/null +++ b/scripts/train_lm1b.sh @@ -0,0 +1,99 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=64000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=8 +#SBATCH --gres=gpu:8 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +sbatch \ + --export=ALL,MODEL=${MODEL} \ + --job-name=train_lm1b_${MODEL} \ + train_lm1b.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export NCCL_P2P_LEVEL=NVL +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (ar, 
mdlm, udlm) +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +RUN_NAME="${MODEL}" + +if [ "${MODEL}" = "ar" ]; then + # AR + DIFFUSION="absorbing_state" + PARAMETERIZATION="ar" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + sampling_use_cache=False +elif [ "${MODEL}" = "mdlm" ]; then + # MDLM + DIFFUSION="absorbing_state" + PARAMETERIZATION="subs" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + sampling_use_cache=True +elif [ "${MODEL}" = "udlm" ]; then + # UDLM + DIFFUSION="uniform" + PARAMETERIZATION="d3pm" + T=0 + TIME_COND=True + ZERO_RECON_LOSS=True + sampling_use_cache=False +else + echo "MODEL must be one of ar, mdlm, udlm" + exit 1 +fi + + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. +srun python -u -m main \ + diffusion="${DIFFUSION}" \ + parameterization="${PARAMETERIZATION}" \ + T=${T} \ + time_conditioning=${TIME_COND} \ + zero_recon_loss=${ZERO_RECON_LOSS} \ + data="lm1b" \ + data.wrap=False \ + data.tokenizer_name_or_path=bert-base-uncased \ + loader.global_batch_size=512 \ + loader.eval_global_batch_size=1024 \ + loader.batch_size=64 \ + loader.eval_batch_size=128 \ + backbone="dit" \ + model=small \ + model.length=128 \ + optim.lr=3e-4 \ + training.guidance=null \ + training.compute_loss_on_pad_tokens=False \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=100_000 \ + trainer.log_every_n_steps=100 \ + trainer.max_steps=1_000_000 \ + trainer.precision=bf16 \ + trainer.val_check_interval=10_000 \ + eval.generate_samples=True \ + sampling.num_sample_batches=1 \ + sampling.batch_size=2 \ + sampling.use_cache=${sampling_use_cache} \ + sampling.steps=128 \ + wandb.name="lm1b_${RUN_NAME}" \ + hydra.run.dir="${PWD}/outputs/lm1b/${RUN_NAME}" diff --git a/scripts/train_qm9_classifier.sh b/scripts/train_qm9_classifier.sh new file mode 100644 index 0000000000000000000000000000000000000000..d7ebd26834d28ae7122e98dff129ff51b4b320a4 --- /dev/null +++ b/scripts/train_qm9_classifier.sh @@ -0,0 +1,68 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=64000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +PROP= +sbatch \ + --export=ALL,DIFFUSION=${DIFFUSION},PROP=${PROP} \ + --job-name=train_qm9_classifier_${PROP}_${DIFFUSION} \ + train_qm9_classifier.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export NCCL_P2P_LEVEL=NVL +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - DIFFUSION (absorbing_state or uniform) +# - PROP (qed or ring_count) +if [ -z "${DIFFUSION}" ]; then + echo "DIFFUSION is not set" + exit 1 +fi +if [ -z "${PROP}" ]; then + echo "PROP is not set" + exit 1 +fi +T=0 +RUN_NAME="${PROP}_${DIFFUSION}_T-${T}" + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. +srun python -u -m main \ + mode=train_classifier \ + diffusion=${DIFFUSION} \ + T=${T} \ + data=qm9 \ + data.label_col="${PROP}" \ + data.label_col_pctile=90 \ + data.num_classes=2 \ + loader.global_batch_size=2048 \ + loader.eval_global_batch_size=4096 \ + classifier_backbone=dit \ + classifier_model=tiny-classifier \ + model.length=32 \ 
optim.lr=3e-4 \ + lr_scheduler=cosine_decay_warmup \ + lr_scheduler.warmup_t=1000 \ + lr_scheduler.lr_min=3e-6 \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \ + callbacks.checkpoint_monitor.monitor=val/cross_entropy \ + trainer.val_check_interval=1.0 \ + trainer.max_steps=25_000 \ + wandb.group=train_classifier \ + wandb.name="qm9-classifier_${RUN_NAME}" \ + hydra.run.dir="${PWD}/outputs/qm9/classifier/${RUN_NAME}" \ No newline at end of file diff --git a/scripts/train_qm9_fudge_classifier.sh b/scripts/train_qm9_fudge_classifier.sh new file mode 100644 index 0000000000000000000000000000000000000000..bba5889839d7b3b755be45cc45833372cb519af1 --- /dev/null +++ b/scripts/train_qm9_fudge_classifier.sh @@ -0,0 +1,66 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=64000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +sbatch \ + --export=ALL,PROP=${PROP} \ + --job-name=train_qm9_fudge_classifier_${PROP} \ + train_qm9_fudge_classifier.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export HYDRA_FULL_ERROR=1 +export NCCL_P2P_LEVEL=NVL + +# Expecting: +# - PROP (qed or ring_count) +if [ -z "${PROP}" ]; then + echo "PROP is not set" + exit 1 +fi +LABEL_SMOOTHING=FALSE +RUN_NAME="${PROP}" + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. +srun python -u -m main \ + mode=train_classifier \ + +is_fudge_classifier=True \ + +use_label_smoothing=${LABEL_SMOOTHING} \ + parameterization=ar \ + data=qm9 \ + data.label_col="${PROP}" \ + data.label_col_pctile=90 \ + data.num_classes=2 \ + loader.global_batch_size=2048 \ + loader.eval_global_batch_size=4096 \ + classifier_model=tiny-classifier \ + classifier_backbone=dit \ + classifier_model.pooling=no_pooling \ + model.length=32 \ + optim.lr=3e-4 \ + lr_scheduler=cosine_decay_warmup \ + lr_scheduler.warmup_t=1000 \ + lr_scheduler.lr_min=3e-6 \ + training.guidance=null \ + +training.use_label_smoothing=${LABEL_SMOOTHING} \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \ + callbacks.checkpoint_monitor.monitor=val/cross_entropy \ + trainer.val_check_interval=1.0 \ + trainer.max_steps=25_000 \ + wandb.group=train_classifier \ + wandb.name="qm9-fudge_classifier_${RUN_NAME}" \ + hydra.run.dir="./outputs/qm9/fudge_classifier/${RUN_NAME}" diff --git a/scripts/train_qm9_guidance.sh b/scripts/train_qm9_guidance.sh new file mode 100644 index 0000000000000000000000000000000000000000..849e36b8fdc09511432f0b898adf12fc41414d70 --- /dev/null +++ b/scripts/train_qm9_guidance.sh @@ -0,0 +1,103 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=64000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +PROP= +sbatch \ 
--export=ALL,MODEL=${MODEL},PROP=${PROP} \ + --job-name=train_qm9_${PROP}_guidance_${MODEL} \ + train_qm9_guidance.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export NCCL_P2P_LEVEL=NVL +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (ar, mdlm, udlm) +# - PROP (qed or ring_count) +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${PROP}" ]; then + echo "PROP is not set" + exit 1 +fi +RUN_NAME="${MODEL}_${PROP}" + +if [ "${MODEL}" = "ar" ]; then + # AR + DIFFUSION="absorbing_state" + PARAMETERIZATION="ar" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + sampling_use_cache=False +elif [ "${MODEL}" = "mdlm" ]; then + # MDLM + DIFFUSION="absorbing_state" + PARAMETERIZATION="subs" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + sampling_use_cache=True +elif [ "${MODEL}" = "udlm" ]; then + # UDLM + DIFFUSION="uniform" + PARAMETERIZATION="d3pm" + T=0 + TIME_COND=True + ZERO_RECON_LOSS=True + sampling_use_cache=False +else + echo "MODEL must be one of ar, mdlm, udlm" + exit 1 +fi + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. +srun python -u -m main \ + diffusion="${DIFFUSION}" \ + parameterization="${PARAMETERIZATION}" \ + T=${T} \ + time_conditioning=${TIME_COND} \ + zero_recon_loss=${ZERO_RECON_LOSS} \ + data=qm9 \ + data.label_col=${PROP} \ + data.label_col_pctile=90 \ + data.num_classes=2 \ + eval.generate_samples=True \ + loader.global_batch_size=2048 \ + loader.eval_global_batch_size=4096 \ + backbone="dit" \ + model=small \ + model.length=32 \ + optim.lr=3e-4 \ + lr_scheduler=cosine_decay_warmup \ + lr_scheduler.warmup_t=1000 \ + lr_scheduler.lr_min=3e-6 \ + training.guidance.cond_dropout=0.1 \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \ + training.compute_loss_on_pad_tokens=True \ + trainer.max_steps=25_000 \ + trainer.val_check_interval=1.0 \ + sampling.num_sample_batches=1 \ + sampling.batch_size=1 \ + sampling.use_cache=${sampling_use_cache} \ + sampling.steps=32 \ + wandb.name="qm9_${RUN_NAME}" \ + hydra.run.dir="${PWD}/outputs/qm9/${RUN_NAME}" diff --git a/scripts/train_qm9_no-guidance.sh b/scripts/train_qm9_no-guidance.sh new file mode 100644 index 0000000000000000000000000000000000000000..7e590641d0e63c5b856b9b2cc0b80510209df828 --- /dev/null +++ b/scripts/train_qm9_no-guidance.sh @@ -0,0 +1,99 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=32000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +sbatch \ + --export=ALL,MODEL=${MODEL} \ + --job-name=train_qm9_no-guidance_${MODEL} \ + train_qm9_no-guidance.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export NCCL_P2P_LEVEL=NVL +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (ar, mdlm, udlm) +# - USE_SIMPLE_CE_LOSS (True, False; optional, default: False) +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${USE_SIMPLE_CE_LOSS}" ]; then + USE_SIMPLE_CE_LOSS=False +fi +RUN_NAME="${MODEL}_no-guidance" +if [ "${USE_SIMPLE_CE_LOSS}" = "True" ]; then + RUN_NAME="${RUN_NAME}_simple-ce" 
+fi + +if [ "${MODEL}" = "ar" ]; then + # AR + DIFFUSION="absorbing_state" + PARAMETERIZATION="ar" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False +elif [ "${MODEL}" = "mdlm" ]; then + # MDLM + DIFFUSION="absorbing_state" + PARAMETERIZATION="subs" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False +elif [ "${MODEL}" = "udlm" ]; then + # UDLM + DIFFUSION="uniform" + PARAMETERIZATION="d3pm" + T=0 + TIME_COND=True + ZERO_RECON_LOSS=True +else + echo "MODEL must be one of ar, mdlm, udlm" + exit 1 +fi + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. +srun python -u -m main \ + diffusion="${DIFFUSION}" \ + parameterization="${PARAMETERIZATION}" \ + T=${T} \ + time_conditioning=${TIME_COND} \ + zero_recon_loss=${ZERO_RECON_LOSS} \ + data=qm9 \ + data.label_col=null \ + data.label_col_pctile=null \ + data.num_classes=null \ + eval.generate_samples=False \ + loader.global_batch_size=2048 \ + loader.eval_global_batch_size=4096 \ + backbone="dit" \ + model=small \ + model.length=32 \ + optim.lr=3e-4 \ + lr_scheduler=cosine_decay_warmup \ + lr_scheduler.warmup_t=1000 \ + lr_scheduler.lr_min=3e-6 \ + training.guidance=null \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \ + training.compute_loss_on_pad_tokens=True \ + training.use_simple_ce_loss=${USE_SIMPLE_CE_LOSS} \ + trainer.max_steps=25_000 \ + trainer.val_check_interval=1.0 \ + wandb.name="qm9_${RUN_NAME}" \ + hydra.run.dir="${PWD}/outputs/qm9/${RUN_NAME}" diff --git a/scripts/train_qm9_pplm_classifier.sh b/scripts/train_qm9_pplm_classifier.sh new file mode 100644 index 0000000000000000000000000000000000000000..a6b8f6ca8602ab1d30a631acd356634a7fbdcb87 --- /dev/null +++ b/scripts/train_qm9_pplm_classifier.sh @@ -0,0 +1,108 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=64000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +PROP= +sbatch \ + --export=ALL,MODEL=${MODEL},PROP=${PROP} \ + --job-name=train_qm9_pplm_classifier_${PROP}_${MODEL} \ + train_qm9_pplm_classifier.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export HYDRA_FULL_ERROR=1 +export NCCL_P2P_LEVEL=NVL + +# Expecting: +# - MODEL (ar, mdlm, or udlm) +# - PROP (qed or ring_count) +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${PROP}" ]; then + echo "PROP is not set" + exit 1 +fi +LABEL_SMOOTHING=FALSE +RUN_NAME="${PROP}_${MODEL}" + +if [ "${MODEL}" = "ar" ]; then + # AR + PARAMETERIZATION="ar" + PRETRAINED_PATH="${PWD}/outputs/qm9/${MODEL}_no-guidance/checkpoints/best.ckpt" + POOLING="attention_mean" + # dummy properties + DIFFUSION="absorbing_state" + T=0 + TIME_COND=False +elif [ "${MODEL}" = "mdlm" ]; then + # MDLM + DIFFUSION="absorbing_state" + PARAMETERIZATION="subs" + T=0 + TIME_COND=False + PRETRAINED_PATH="${PWD}/outputs/qm9/${MODEL}_no-guidance/checkpoints/best.ckpt" + POOLING="mean" +elif [ "${MODEL}" = "udlm" ]; then + # UDLM + DIFFUSION="uniform" + PARAMETERIZATION="d3pm" + T=0 + TIME_COND=True + PRETRAINED_PATH="${PWD}/outputs/qm9/${MODEL}_no-guidance/checkpoints/best.ckpt" + 
POOLING="mean" +else + echo "MODEL must be one of ar, mdlm, udlm" + exit 1 +fi + +# To enable preemption re-loading, set `hydra.run.dir` or +srun python -u -m main \ + mode=train_classifier \ + +is_pplm_classifier=True \ + +use_label_smoothing=${LABEL_SMOOTHING} \ + eval.checkpoint_path="${PRETRAINED_PATH}" \ + parameterization=${PARAMETERIZATION} \ + time_conditioning=${TIME_COND} \ + diffusion=${DIFFUSION} \ + T=${T} \ + data=qm9 \ + data.label_col="${PROP}" \ + data.label_col_pctile=90 \ + data.num_classes=2 \ + loader.global_batch_size=2048 \ + loader.eval_global_batch_size=4096 \ + classifier_model=small-classifier \ + classifier_backbone=dit \ + classifier_model.pooling=${POOLING} \ + model.length=32 \ + +classifier_model.freeze_encoder=True \ + +classifier_model.use_encoder_ema=True \ + optim.lr=3e-5 \ + lr_scheduler=cosine_decay_warmup \ + lr_scheduler.warmup_t=1000 \ + lr_scheduler.lr_min=3e-7 \ + training.guidance=null \ + +training.use_label_smoothing=${LABEL_SMOOTHING} \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \ + callbacks.checkpoint_monitor.monitor=val/cross_entropy \ + trainer.val_check_interval=1.0 \ + trainer.max_steps=25_000 \ + wandb.group=train_classifier \ + wandb.name="qm9-pplm_classifier_${RUN_NAME}" \ + hydra.run.dir="./outputs/qm9/pplm_classifier/${RUN_NAME}" diff --git a/scripts/train_ten_species_classifier.sh b/scripts/train_ten_species_classifier.sh new file mode 100644 index 0000000000000000000000000000000000000000..8e181fb812074ec79d8a86890238a217108a01db --- /dev/null +++ b/scripts/train_ten_species_classifier.sh @@ -0,0 +1,67 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the users login environment +#SBATCH --mem=64000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +< +sbatch \ + --export=ALL,DIFFUSION=${DIFFUSION} \ + --job-name=train_ten_species_classifier_${DIFFUSION} \ + train_ten_species_classifier.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export NCCL_P2P_LEVEL=NVL +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - DIFFUSION (absorbing_state or uniform) +if [ -z "${DIFFUSION}" ]; then + echo "DIFFUSION is not set" + exit 1 +fi +T=0 +RUN_NAME="${DIFFUSION}_T-${T}" + +# To enable preemption re-loading, set `hydra.run.dir` or +srun python -u -m main \ + mode=train_classifier \ + diffusion=${DIFFUSION} \ + T=${T} \ + data=ten_species \ + loader.global_batch_size=32 \ + loader.eval_global_batch_size=64 \ + classifier_backbone=dimamba \ + classifier_model=tiny-dimamba-classifier \ + classifier_model.bidirectional=True \ + classifier_model.bidirectional_strategy=add \ + classifier_model.bidirectional_weight_tie=True \ + model=dimamba \ + backbone=dimamba \ + model.length=32768 \ + model.bidirectional=True \ + model.bidirectional_strategy=add \ + model.bidirectional_weight_tie=True \ + optim.lr=2e-3 \ + lr_scheduler=cosine_decay_warmup \ + lr_scheduler.warmup_t=3000 \ + lr_scheduler.lr_min=2e-6 \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=6_000 \ + callbacks.checkpoint_monitor.monitor=val/cross_entropy \ + trainer.val_check_interval=3_000 \ + trainer.max_steps=30_000 \ + 
wandb.group=train_classifier \ + wandb.name="ten_species-classifier_${RUN_NAME}" \ + hydra.run.dir="${PWD}/outputs/ten_species/classifier/${RUN_NAME}" diff --git a/scripts/train_ten_species_eval_classifier.sh b/scripts/train_ten_species_eval_classifier.sh new file mode 100644 index 0000000000000000000000000000000000000000..c005c362b5f67cde481ad86e3bb85332a8d9d507 --- /dev/null +++ b/scripts/train_ten_species_eval_classifier.sh @@ -0,0 +1,53 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=32000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=8 +#SBATCH --gres=gpu:8 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +sbatch \ + --export=ALL,MODEL=${MODEL} \ + --job-name=train_ten_species_guidance_${MODEL} \ + train_ten_species_guidance.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export NCCL_P2P_LEVEL=NVL +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (ar, mdlm, udlm) +# - USE_SIMPLE_CE_LOSS (True, False; optional, default: False) +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${USE_SIMPLE_CE_LOSS}" ]; then + USE_SIMPLE_CE_LOSS=False +fi +RUN_NAME="${MODEL}" +if [ "${USE_SIMPLE_CE_LOSS}" = "True" ]; then + RUN_NAME="${RUN_NAME}_simple-ce" +fi + +if [ "${MODEL}" = "ar" ]; then + # AR + DIFFUSION="absorbing_state" + PARAMETERIZATION="ar" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + BIDIRECTIONAL=False + BIDIRECTIONAL_STRATEGY=null + BIDIRECTIONAL_WEIGHT_TIE=null +elif [ "${MODEL}" = "mdlm" ]; then + # MDLM + DIFFUSION="absorbing_state" + PARAMETERIZATION="subs" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + BIDIRECTIONAL=True + BIDIRECTIONAL_STRATEGY=add + BIDIRECTIONAL_WEIGHT_TIE=True +elif [ "${MODEL}" = "udlm" ]; then + # UDLM + DIFFUSION="uniform" + PARAMETERIZATION="d3pm" + T=0 + TIME_COND=True + ZERO_RECON_LOSS=True + BIDIRECTIONAL=True + BIDIRECTIONAL_STRATEGY=add + BIDIRECTIONAL_WEIGHT_TIE=True +else + echo "MODEL must be one of ar, mdlm, udlm" + exit 1 +fi + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. +srun python -u -m main \ + diffusion=${DIFFUSION} \ + parameterization=${PARAMETERIZATION} \ + T=${T} \ + time_conditioning=${TIME_COND} \ + zero_recon_loss=${ZERO_RECON_LOSS} \ + data=ten_species \ + eval.generate_samples=False \ + loader.global_batch_size=32 \ + loader.eval_global_batch_size=64 \ + loader.batch_size=2 \ + backbone=dimamba \ + model=dimamba \ + model.bidirectional=${BIDIRECTIONAL} \ + model.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} \ + model.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} \ + model.length=32768 \ + optim.lr=2e-3 \ + lr_scheduler=cosine_decay_warmup \ + lr_scheduler.warmup_t=3000 \ + lr_scheduler.lr_min=2e-6 \ + training.guidance.cond_dropout=0.1 \ + training.compute_loss_on_pad_tokens=False \ + training.use_simple_ce_loss=${USE_SIMPLE_CE_LOSS} \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=6_000 \ + trainer.max_steps=30_000 \ + trainer.val_check_interval=3_000 \ + wandb.name="ten_species_${RUN_NAME}" \ + hydra.run.dir="./outputs/ten_species/${RUN_NAME}" diff --git a/scripts/train_ten_species_no-guidance.sh b/scripts/train_ten_species_no-guidance.sh new file mode 100644 index 
1597586fc92af57d80d29cfc6d8735d1c5a5149c --- /dev/null +++ b/scripts/train_ten_species_no-guidance.sh @@ -0,0 +1,108 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=64000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +sbatch \ + --export=ALL,MODEL=${MODEL} \ + --job-name=train_ten_species_no-guidance_${MODEL} \ + train_ten_species_no-guidance.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export NCCL_P2P_LEVEL=NVL +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (ar, mdlm, udlm) +# - USE_SIMPLE_CE_LOSS (True, False; optional, default: False) +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +if [ -z "${USE_SIMPLE_CE_LOSS}" ]; then + USE_SIMPLE_CE_LOSS=False +fi +RUN_NAME="${MODEL}_no-guidance" +if [ "${USE_SIMPLE_CE_LOSS}" = "True" ]; then + RUN_NAME="${RUN_NAME}_simple-ce" +fi + +if [ "${MODEL}" = "ar" ]; then + # AR + DIFFUSION="absorbing_state" + PARAMETERIZATION="ar" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + BIDIRECTIONAL=False + BIDIRECTIONAL_STRATEGY=null + BIDIRECTIONAL_WEIGHT_TIE=null +elif [ "${MODEL}" = "mdlm" ]; then + # MDLM + DIFFUSION="absorbing_state" + PARAMETERIZATION="subs" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + BIDIRECTIONAL=True + BIDIRECTIONAL_STRATEGY=add + BIDIRECTIONAL_WEIGHT_TIE=True +elif [ "${MODEL}" = "udlm" ]; then + # UDLM + DIFFUSION="uniform" + PARAMETERIZATION="d3pm" + T=0 + TIME_COND=True + ZERO_RECON_LOSS=True + BIDIRECTIONAL=True + BIDIRECTIONAL_STRATEGY=add + BIDIRECTIONAL_WEIGHT_TIE=True +else + echo "MODEL must be one of ar, mdlm, udlm" + exit 1 +fi + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. +srun python -u -m main \ + diffusion=${DIFFUSION} \ + parameterization=${PARAMETERIZATION} \ + T=${T} \ + time_conditioning=${TIME_COND} \ + zero_recon_loss=${ZERO_RECON_LOSS} \ + data=ten_species \ + eval.generate_samples=False \ + loader.global_batch_size=32 \ + loader.eval_global_batch_size=64 \ + loader.batch_size=2 \ + backbone=dimamba \ + model=dimamba \ + model.bidirectional=${BIDIRECTIONAL} \ + model.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} \ + model.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} \ + model.length=32768 \ + optim.lr=2e-3 \ + lr_scheduler=cosine_decay_warmup \ + lr_scheduler.warmup_t=3000 \ + lr_scheduler.lr_min=2e-6 \ + training.guidance=null \ + training.compute_loss_on_pad_tokens=False \ + training.use_simple_ce_loss=${USE_SIMPLE_CE_LOSS} \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=6_000 \ + trainer.max_steps=30_000 \ + trainer.val_check_interval=3_000 \ + wandb.name="ten_species_${RUN_NAME}" \ + hydra.run.dir="./outputs/ten_species/${RUN_NAME}" \ No newline at end of file diff --git a/scripts/train_ten_species_pplm_classifier.sh b/scripts/train_ten_species_pplm_classifier.sh new file mode 100644 index 0000000000000000000000000000000000000000..24516503d71025cc1b8dc6ccc72ded6127516ed4 --- /dev/null +++ b/scripts/train_ten_species_pplm_classifier.sh @@ -0,0 +1,116 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # 
output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the user's login environment +#SBATCH --mem=64000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +<<comment +sbatch \ + --export=ALL,MODEL=${MODEL} \ + --job-name=train_ten_species_pplm_classifier_${MODEL} \ + train_ten_species_pplm_classifier.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export HYDRA_FULL_ERROR=1 +export NCCL_P2P_LEVEL=NVL + +# Expecting: +# - MODEL (ar, mdlm, or udlm) +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +LABEL_SMOOTHING=FALSE +RUN_NAME="${MODEL}_lr-2e-3" + +if [ "${MODEL}" = "ar" ]; then + # AR + PARAMETERIZATION="ar" + PRETRAINED_PATH="${PWD}/outputs/ten_species/${MODEL}_no-guidance/checkpoints/best.ckpt" + POOLING="attention_mean" + # dummy properties + DIFFUSION="absorbing_state" + T=0 + TIME_COND=False + BIDIRECTIONAL=False + BIDIRECTIONAL_STRATEGY=null + BIDIRECTIONAL_WEIGHT_TIE=null +elif [ "${MODEL}" = "mdlm" ]; then + # MDLM + DIFFUSION="absorbing_state" + PARAMETERIZATION="subs" + T=0 + TIME_COND=False + PRETRAINED_PATH="${PWD}/outputs/ten_species/${MODEL}_no-guidance/checkpoints/best.ckpt" + POOLING="mean" + BIDIRECTIONAL=True + BIDIRECTIONAL_STRATEGY=add + BIDIRECTIONAL_WEIGHT_TIE=True +elif [ "${MODEL}" = "udlm" ]; then + # UDLM + DIFFUSION="uniform" + PARAMETERIZATION="d3pm" + T=0 + TIME_COND=True + PRETRAINED_PATH="${PWD}/outputs/ten_species/${MODEL}_no-guidance/checkpoints/best.ckpt" + POOLING="mean" + BIDIRECTIONAL=True + BIDIRECTIONAL_STRATEGY=add + BIDIRECTIONAL_WEIGHT_TIE=True +else + echo "MODEL must be one of ar, mdlm, udlm" + exit 1 +fi + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. +srun python -u -m main \ + mode=train_classifier \ + +is_pplm_classifier=True \ + +use_label_smoothing=${LABEL_SMOOTHING} \ + eval.checkpoint_path="${PRETRAINED_PATH}" \ + parameterization=${PARAMETERIZATION} \ + time_conditioning=${TIME_COND} \ + diffusion=${DIFFUSION} \ + T=${T} \ + data=ten_species \ + loader.global_batch_size=32 \ + loader.eval_global_batch_size=64 \ + model=dimamba \ + backbone=dimamba \ + model.bidirectional=${BIDIRECTIONAL} \ + model.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} \ + model.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} \ + model.length=32768 \ + classifier_model=dimamba-classifier \ + classifier_backbone=dimamba \ + classifier_model.pooling=${POOLING} \ + classifier_model.bidirectional=${BIDIRECTIONAL} \ + classifier_model.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} \ + classifier_model.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} \ + +classifier_model.freeze_encoder=True \ + +classifier_model.use_encoder_ema=True \ + optim.lr=2e-3 \ + lr_scheduler=cosine_decay_warmup \ + lr_scheduler.warmup_t=3000 \ + lr_scheduler.lr_min=2e-6 \ + training.guidance=null \ + +training.use_label_smoothing=${LABEL_SMOOTHING} \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=6_000 \ + callbacks.checkpoint_monitor.monitor=val/cross_entropy \ + trainer.val_check_interval=3_000 \ + trainer.max_steps=30_000 \ + wandb.group=train_classifier \ + wandb.name="ten_species-pplm_classifier_${RUN_NAME}" \ 
hydra.run.dir="./outputs/ten_species/pplm_classifier/${RUN_NAME}" diff --git a/scripts/train_text8.sh b/scripts/train_text8.sh new file mode 100644 index 0000000000000000000000000000000000000000..4b70f0657415ce8a20420eaea55f1c5a5e531797 --- /dev/null +++ b/scripts/train_text8.sh @@ -0,0 +1,97 @@ +#!/bin/bash +#SBATCH -o ../watch_folder/%x_%j.out # output file (%j expands to jobID) +#SBATCH -N 1 # Total number of nodes requested +#SBATCH --get-user-env # retrieve the users login environment +#SBATCH --mem=64000 # server memory requested (per node) +#SBATCH -t 960:00:00 # Time limit (hh:mm:ss) +#SBATCH --constraint="[a100|a6000|a5000|3090]" +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 # Type/number of GPUs needed +#SBATCH --open-mode=append # Do not overwrite logs +#SBATCH --requeue # Requeue upon preemption + +< +sbatch \ + --export=ALL,MODEL=${MODEL} \ + --job-name=train_text8_${MODEL} \ + train_text8.sh +comment + +# Setup environment +cd ../ || exit # Go to the root directory of the repo +source setup_env.sh +export NCCL_P2P_LEVEL=NVL +export HYDRA_FULL_ERROR=1 + +# Expecting: +# - MODEL (ar, mdlm, udlm) +if [ -z "${MODEL}" ]; then + echo "MODEL is not set" + exit 1 +fi +RUN_NAME="${MODEL}" + +if [ "${MODEL}" = "ar" ]; then + # AR + DIFFUSION="absorbing_state" + PARAMETERIZATION="ar" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + sampling_use_cache=False +elif [ "${MODEL}" = "mdlm" ]; then + # MDLM + DIFFUSION="absorbing_state" + PARAMETERIZATION="subs" + T=0 + TIME_COND=False + ZERO_RECON_LOSS=False + sampling_use_cache=True +elif [ "${MODEL}" = "udlm" ]; then + # UDLM + DIFFUSION="uniform" + PARAMETERIZATION="d3pm" + T=0 + TIME_COND=True + ZERO_RECON_LOSS=True + sampling_use_cache=False +else + echo "MODEL must be one of ar, mdlm, udlm" + exit 1 +fi + + +# To enable preemption re-loading, set `hydra.run.dir` or +# `checkpointing.save_dir` explicitly. +srun python -u -m main \ + diffusion="${DIFFUSION}" \ + parameterization="${PARAMETERIZATION}" \ + T=${T} \ + time_conditioning=${TIME_COND} \ + zero_recon_loss=${ZERO_RECON_LOSS} \ + data="text8" \ + data.wrap=True \ + data.tokenizer_name_or_path=text8 \ + loader.global_batch_size=512 \ + loader.eval_global_batch_size=1024 \ + backbone="dit" \ + model=small \ + model.length=256 \ + optim.lr=3e-4 \ + training.guidance=null \ + callbacks.checkpoint_every_n_steps.every_n_train_steps=100_000 \ + trainer.log_every_n_steps=100 \ + trainer.max_steps=1_000_000 \ + trainer.precision=bf16 \ + trainer.val_check_interval=5_000 \ + +trainer.check_val_every_n_epoch=null \ + eval.generate_samples=True \ + sampling.num_sample_batches=1 \ + sampling.batch_size=2 \ + sampling.use_cache=${sampling_use_cache} \ + sampling.steps=256 \ + wandb.name="text8_${RUN_NAME}" \ + hydra.run.dir="${PWD}/outputs/text8/${RUN_NAME}"