diff --git a/.gitignore b/.gitignore index 891c83b36247cfa9d8ac525fd1c1eee4918b4ced..1a84dff171766ba48b1d71654dc7589011b37077 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,9 @@ dpacman/nohup.out dpacman/*/__pycache__/ dpacman/data_tasks/split/__pycache__/ dpacman/data_tasks/cluster/__pycache__/ -dpacman/data_tasks/embeddings/__pycache__/ \ No newline at end of file +dpacman/data_tasks/embeddings/__pycache__/ +dpacman/combine_shards.py +dpacman/combine.log +dpacman/loss_sim.py +dpacman/loss_temp.py +dpacman/peak_examples/ \ No newline at end of file diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..830819099811541e730c74779b1888fceb85d433 --- /dev/null +++ b/configs/callbacks/default.yaml @@ -0,0 +1,21 @@ +defaults: + - model_checkpoint + - early_stopping + - model_summary + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:03d}" + monitor: "val/loss" + mode: "min" + save_last: True + auto_insert_metric_name: False + +early_stopping: + monitor: "val/loss" + patience: 100 + mode: "min" + +model_summary: + max_depth: -1 diff --git a/configs/callbacks/early_stopping.yaml b/configs/callbacks/early_stopping.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c826c8d58651a5e2c7cca0e99948a9b6ccabccf3 --- /dev/null +++ b/configs/callbacks/early_stopping.yaml @@ -0,0 +1,15 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html + +early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: ??? # quantity to be monitored, must be specified !!! + min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement + patience: 3 # number of checks with no improvement after which training will be stopped + verbose: False # verbosity mode + mode: "min" # "max" means higher metric value is better, can be also "min" + strict: True # whether to crash the training if monitor is not found in the validation metrics + check_finite: True # when set True, stops training when the monitor becomes NaN or infinite + stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold + divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold + check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch + # log_rank_zero_only: False # this keyword argument isn't available in stable version diff --git a/configs/callbacks/model_checkpoint.yaml b/configs/callbacks/model_checkpoint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bf946e88b1ecfaf96efa91428e4f38e17267b25f --- /dev/null +++ b/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,17 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: null # directory to save the model file + filename: null # checkpoint filename + monitor: null # name of the logged metric which determines when model is improving + verbose: False # verbosity mode + save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 1 # save k best models (determined by above metric) + mode: "min" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: True # when True, the checkpoints 
filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: null # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/configs/callbacks/model_summary.yaml b/configs/callbacks/model_summary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b75981d8cd5d73f61088d80495dc540274bca3d1 --- /dev/null +++ b/configs/callbacks/model_summary.yaml @@ -0,0 +1,5 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html + +model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 1 # the maximum depth of layer nesting that the summary will include diff --git a/dpacman/classifier/model/__init__.py b/configs/callbacks/none.yaml similarity index 100% rename from dpacman/classifier/model/__init__.py rename to configs/callbacks/none.yaml diff --git a/configs/callbacks/rich_progress_bar.yaml b/configs/callbacks/rich_progress_bar.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de6f1ccb11205a4db93645fb6f297e50205de172 --- /dev/null +++ b/configs/callbacks/rich_progress_bar.yaml @@ -0,0 +1,4 @@ +# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html + +rich_progress_bar: + _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/configs/data_module/pair.yaml b/configs/data_module/pair.yaml new file mode 100644 index 0000000000000000000000000000000000000000..67f3a417ddfaa03158d5077dc1fdb6fd5a16eb69 --- /dev/null +++ b/configs/data_module/pair.yaml @@ -0,0 +1,13 @@ +_target_: dpacman.data_modules.pair.PairDataModule + +train_file: data_files/processed/splits/by_dna/babytrain.csv +val_file: data_files/processed/splits/by_dna/babyval.csv +test_file: data_files/processed/splits/by_dna/babytest.csv + +tr_shelf_path: data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf +dna_shelf_path: data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf + +batch_size: 32 +num_workers: 8 + +maximize_num_workers: False \ No newline at end of file diff --git a/configs/data_modules/pair.yaml b/configs/data_modules/pair.yaml deleted file mode 100644 index 3f8d9446295170f47de680175d8c48614ac0f926..0000000000000000000000000000000000000000 --- a/configs/data_modules/pair.yaml +++ /dev/null @@ -1,9 +0,0 @@ - -train_file: data_files/splits/train.csv -val_file: data_files/splits/val.csv -test_file: data_files/splits/test.csv - -batch_size: 32 -num_workers: 8 - -maximize_num_workers: False \ No newline at end of file diff --git a/configs/data_task/cluster/remap.yaml b/configs/data_task/cluster/remap.yaml index 6bb042a6c0da69c6f10be1ae4f0492db4f9d788c..9406ed2300e3bfda9b40c5d9cfedb5c8a51f2f86 100644 --- a/configs/data_task/cluster/remap.yaml +++ b/configs/data_task/cluster/remap.yaml @@ -1,5 +1,5 @@ name: remap -type: cluster +task_type: cluster max_protein_length: 1998 diff --git a/configs/data_task/download/genome.yaml b/configs/data_task/download/genome.yaml index a67cc3d99ec2fb141b9d9026cb507f3e7771c4f9..5ad60bae74aed1a1dbe85ba4c480bf7bfe064250 100644 --- a/configs/data_task/download/genome.yaml +++ b/configs/data_task/download/genome.yaml @@ -1,5 +1,5 @@ name: genome -type: 
download +task_type: download output_dir: dpacman/data_files/raw/genomes genomes: - hg38 \ No newline at end of file diff --git a/configs/data_task/download/remap.yaml b/configs/data_task/download/remap.yaml index 9858326fb2fd0d89ceee331060ba66619b3b6c15..c7d295fa391f70aead03c4341c6d8b9f49e14522 100644 --- a/configs/data_task/download/remap.yaml +++ b/configs/data_task/download/remap.yaml @@ -1,5 +1,5 @@ name: remap -type: download +task_type: download nr_url: https://remap.univ-amu.fr/storage/remap2022/hg38/MACS2/remap2022_nr_macs2_hg38_v1_0.bed.gz nr_output_dir: dpacman/data_files/raw/remap diff --git a/configs/data_task/embeddings/dna.yaml b/configs/data_task/embeddings/dna.yaml index ec20960f4424e1cfa0d7693d1e4f1255cdac9d65..dd2c19abf2074d6a4a32015a7ec10274298e5ee3 100644 --- a/configs/data_task/embeddings/dna.yaml +++ b/configs/data_task/embeddings/dna.yaml @@ -1,9 +1,13 @@ name: dna -type: embeddings +task_type: embeddings genome_json_dir: null -chrom_model: caduceus -input_file: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence.json +chrom_model: segmentnt +input_file: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence_with_rc.json out_dir: dpacman/data_files/processed/embeddings/fimo_hits_only -device: gpu \ No newline at end of file +device: gpu + +batch_size: 1 + +debug: false \ No newline at end of file diff --git a/configs/data_task/embeddings/protein.yaml b/configs/data_task/embeddings/protein.yaml index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..c8fcc6ab8dc8362071b63ce7cf68f36c26cda9aa 100644 --- a/configs/data_task/embeddings/protein.yaml +++ b/configs/data_task/embeddings/protein.yaml @@ -0,0 +1,14 @@ +name: protein +task_type: embeddings + +prot_model: esm +input_file: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/tr_seqid_to_tr_sequence.json +out_dir: dpacman/data_files/processed/embeddings/fimo_hits_only + +device: gpu + +save_as_shelf: true + +batch_size: 1 + +debug: false \ No newline at end of file diff --git a/configs/data_task/fimo/post_fimo.yaml b/configs/data_task/fimo/post_fimo.yaml index b0e028a45103859b0691a2475729eabbe10ed81a..08ede7b96b368c70416f7bf304a2fcc587622c3c 100644 --- a/configs/data_task/fimo/post_fimo.yaml +++ b/configs/data_task/fimo/post_fimo.yaml @@ -1,5 +1,5 @@ name: post_fimo -type: fimo +task_type: fimo fimo_out_dir: dpacman/data_files/processed/fimo/fimo_out_q processed_output_csv: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed.csv diff --git a/configs/data_task/fimo/pre_fimo.yaml b/configs/data_task/fimo/pre_fimo.yaml index 63f4646e0147bc56e94d1f725c316e5d6668f8aa..c9e48a87d32aa8b23f615c2490d240f1f91d572c 100644 --- a/configs/data_task/fimo/pre_fimo.yaml +++ b/configs/data_task/fimo/pre_fimo.yaml @@ -1,5 +1,5 @@ name: pre_fimo -type: fimo +task_type: fimo paths: input_csv: dpacman/data_files/processed/remap/remap2022_crm_macs2_hg38_v1_0_clean.tsv diff --git a/configs/data_task/fimo/run_fimo.yaml b/configs/data_task/fimo/run_fimo.yaml index 01902d0fcf4cd77ed5099762615864fa2d1dea7a..fcfa20026ba21efdb57fb9d9e9a6b71618687f50 100644 --- a/configs/data_task/fimo/run_fimo.yaml +++ b/configs/data_task/fimo/run_fimo.yaml @@ -1,5 +1,5 @@ name: run_fimo -type: fimo +task_type: fimo debug: true diff --git a/configs/data_task/split/remap.yaml b/configs/data_task/split/remap.yaml index 77a345dc933f315d85843aec5685f5a22ba768fa..255ee48923e01c92f5b5d5376947b7ea935dee71 100644 --- 
a/configs/data_task/split/remap.yaml +++ b/configs/data_task/split/remap.yaml @@ -10,7 +10,10 @@ cluster_output_paths: input_data_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed_seed0.parquet split_out_dir: dpacman/data_files/processed/splits +dna_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence.json + split_by: both # protein, dna, or both +augment_rc: true test_ratio: 0.10 val_ratio: 0.10 diff --git a/configs/extras/default.yaml b/configs/extras/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9c6b622283a647fbc513166fc14f016cc3ed8a0 --- /dev/null +++ b/configs/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: False + +# ask user for tags if none are provided in the config +enforce_tags: True + +# pretty print config tree at the start of the run using Rich library +print_config: True diff --git a/configs/logger/aim.yaml b/configs/logger/aim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f9f6adad7feb2780c2efd5ddb0ed053621e05f8 --- /dev/null +++ b/configs/logger/aim.yaml @@ -0,0 +1,28 @@ +# https://aimstack.io/ + +# example usage in lightning module: +# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py + +# open the Aim UI with the following command (run in the folder containing the `.aim` folder): +# `aim up` + +aim: + _target_: aim.pytorch_lightning.AimLogger + repo: ${paths.root_dir} # .aim folder will be created here + # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# + + # aim allows to group runs under experiment name + experiment: null # any string, set to "default" if not specified + + train_metric_prefix: "train/" + val_metric_prefix: "val/" + test_metric_prefix: "test/" + + # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) + system_tracking_interval: 10 # set to null to disable system metrics tracking + + # enable/disable logging of system params such as installed packages, git info, env vars, etc. 
+ log_system_params: true + + # enable/disable tracking console logs (default value is true) + capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0789274e2137ee6c97ca37a5d56c2b8abaf0aaa --- /dev/null +++ b/configs/logger/comet.yaml @@ -0,0 +1,12 @@ +# https://www.comet.ml + +comet: + _target_: lightning.pytorch.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + save_dir: "${paths.output_dir}" + project_name: "lightning-hydra-template" + rest_api_key: null + # experiment_name: "" + experiment_key: null # set to resume experiment + offline: False + prefix: "" diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa028e9c146430c319101ffdfce466514338591c --- /dev/null +++ b/configs/logger/csv.yaml @@ -0,0 +1,7 @@ +# csv logger built in lightning + +csv: + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger + save_dir: "${paths.output_dir}" + name: "csv/" + prefix: "" diff --git a/configs/logger/many_loggers.yaml b/configs/logger/many_loggers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd586800bdccb4e8f4b0236a181b7ddd756ba9ab --- /dev/null +++ b/configs/logger/many_loggers.yaml @@ -0,0 +1,9 @@ +# train with many loggers at once + +defaults: + # - comet + - csv + # - mlflow + # - neptune + - tensorboard + - wandb diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8fb7e685fa27fc8141387a421b90a0b9b492d9e --- /dev/null +++ b/configs/logger/mlflow.yaml @@ -0,0 +1,12 @@ +# https://mlflow.org + +mlflow: + _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger + # experiment_name: "" + # run_name: "" + tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI + tags: null + # save_dir: "./mlruns" + prefix: "" + artifact_location: null + # run_id: "" diff --git a/configs/logger/neptune.yaml b/configs/logger/neptune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8233c140018ecce6ab62971beed269991d31c89b --- /dev/null +++ b/configs/logger/neptune.yaml @@ -0,0 +1,9 @@ +# https://neptune.ai + +neptune: + _target_: lightning.pytorch.loggers.neptune.NeptuneLogger + api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable + project: username/lightning-hydra-template + # name: "" + log_model_checkpoints: True + prefix: "" diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2bd31f6d8ba68d1f5c36a804885d5b9f9c1a9302 --- /dev/null +++ b/configs/logger/tensorboard.yaml @@ -0,0 +1,10 @@ +# https://www.tensorflow.org/tensorboard/ + +tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.output_dir}/tensorboard/" + name: null + log_graph: False + default_hp_metric: True + prefix: "" + # version: "" diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..69ec44e129b56981e37670ed13cdc0a56f84aef9 --- /dev/null +++ b/configs/logger/wandb.yaml @@ -0,0 +1,16 @@ +# https://wandb.ai + +wandb: + _target_: 
lightning.pytorch.loggers.wandb.WandbLogger + # name: "" # name of the run (normally generated by wandb) + save_dir: "${paths.output_dir}" + offline: False + id: null # pass correct id to resume experiment! + anonymous: null # enable anonymous logging + project: "dnabind" + log_model: False # upload lightning ckpts + prefix: "" # a string to put at the beginning of metric keys + # entity: "" # set to name of your wandb team + group: "" + tags: [] + job_type: "" diff --git a/configs/model/classifier.yaml b/configs/model/classifier.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6b8b40fc30b36d8e0c875cea770fab4d8c70e939 --- /dev/null +++ b/configs/model/classifier.yaml @@ -0,0 +1,9 @@ +_target_: dpacman.classifier.model.BindPredictor + +lr: 1e-4 +alpha: 20 +gamma: 20 +weight_decay: 0.01 + +glm_input_dim: 1029 +compressed_dim: 1029 \ No newline at end of file diff --git a/configs/models/pooling/truncatedsvd.yaml b/configs/model/pooling/truncatedsvd.yaml similarity index 100% rename from configs/models/pooling/truncatedsvd.yaml rename to configs/model/pooling/truncatedsvd.yaml diff --git a/configs/models/classifier.yaml b/configs/models/classifier.yaml deleted file mode 100644 index 98b47fccafa9848d56dd87d65cec1f1b0482ede7..0000000000000000000000000000000000000000 --- a/configs/models/classifier.yaml +++ /dev/null @@ -1,11 +0,0 @@ -name: classifier -type: train - -params: - epochs: 10 - batch_size: 32 - lr: 1e-4 - seed: 42 - -out_dir: null -pair_list: null \ No newline at end of file diff --git a/configs/preprocess.yaml b/configs/preprocess.yaml index 5ea23372a8939b2845c164f0b6de9dca2bee6328..bfccdd3ff4c808178c4eb1bfacdf0f7046962b55 100644 --- a/configs/preprocess.yaml +++ b/configs/preprocess.yaml @@ -6,4 +6,4 @@ defaults: - hydra: default # ← tells Hydra to use the logging/output config - data_task: download/genome -task_name: preprocess/${data_task.type} +task_name: preprocess/${data_task.task_type} diff --git a/configs/train.yaml b/configs/train.yaml index 1246e38ecaec2908a633bb7b8ef57baba621b98a..b7ad9449801b233958dbad3643e8a5c2f744c074 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -2,7 +2,42 @@ defaults: - _self_ - paths: default - hydra: default # ← tells Hydra to use the logging/output config + - data_module: pair + - model: classifier - trainer: gpu - - data_task: model/classifier + - extras: default + - logger: wandb + - callbacks: default -task_name: train/${data_task.type} \ No newline at end of file + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + + # config for hyperparameter optimization + - hparams_search: null + + # debugging config (enable through command line, e.g. 
`python train.py debug=default) + - debug: null + +task_name: train/${model} + +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +tags: ["dev"] + +# set False to skip model training +train: True + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: True + +# simply provide checkpoint path to resume training +ckpt_path: null + +# seed for random number generators in pytorch, numpy and python.random +seed: 42 + +trainer: + max_epochs: 20 \ No newline at end of file diff --git a/configs/trainer/cpu.yaml b/configs/trainer/cpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7d6767e60c956567555980654f15e7bb673a41f --- /dev/null +++ b/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: cpu +devices: 1 diff --git a/configs/trainer/ddp.yaml b/configs/trainer/ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab8f89004c399a33440f014fa27e040d4e952bc2 --- /dev/null +++ b/configs/trainer/ddp.yaml @@ -0,0 +1,9 @@ +defaults: + - default + +strategy: ddp + +accelerator: gpu +devices: 4 +num_nodes: 1 +sync_batchnorm: True diff --git a/configs/trainer/ddp_sim.yaml b/configs/trainer/ddp_sim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8404419e5c295654967d0dfb73a7366e75be2f1f --- /dev/null +++ b/configs/trainer/ddp_sim.yaml @@ -0,0 +1,7 @@ +defaults: + - default + +# simulate DDP on CPU, useful for debugging +accelerator: cpu +devices: 2 +strategy: ddp_spawn diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..50905e7fdf158999e7c726edfff1a4dc16d548da --- /dev/null +++ b/configs/trainer/default.yaml @@ -0,0 +1,19 @@ +_target_: lightning.pytorch.trainer.Trainer + +default_root_dir: ${paths.output_dir} + +min_epochs: 1 # prevents early stopping +max_epochs: 10 + +accelerator: cpu +devices: 1 + +# mixed precision for extra speed-up +# precision: 16 + +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + +# set True to to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: False diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2389510a90f5f0161cff6ccfcb4a96097ddf9a1 --- /dev/null +++ b/configs/trainer/gpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: gpu +devices: 1 diff --git a/configs/trainer/mps.yaml b/configs/trainer/mps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ecf6d5cc3a34ca127c5510f4a18e989561e38e4 --- /dev/null +++ b/configs/trainer/mps.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: mps +devices: 1 diff --git a/dpacman/classifier/loss.py b/dpacman/classifier/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4c624b4b1dfbd585fec3eac4e8a0bf187b479f76 --- /dev/null +++ b/dpacman/classifier/loss.py @@ -0,0 +1,58 @@ +""" +Define loss functions needed for training the model +""" + +import torch +from torch.nn import functional as F + + +def bce_loss_masked(logits, targets, nonpeak_mask, pos_weight=None): + """ + Compute the masked Binary Cross Entropy, only on certain positions. 
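# Illustrative sketch (hypothetical entry point, not part of this diff) of how the Hydra config
# groups added above (data_module, model, trainer, logger, callbacks) are typically consumed:
# each `_target_` node is instantiated and handed to the Lightning Trainer. The config_path and
# file location are assumptions; only the config group names come from configs/train.yaml.
import hydra
from omegaconf import DictConfig


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig):
    datamodule = hydra.utils.instantiate(cfg.data_module)              # PairDataModule
    model = hydra.utils.instantiate(cfg.model)                         # BindPredictor
    callbacks = [hydra.utils.instantiate(c) for c in cfg.callbacks.values()]
    logger = [hydra.utils.instantiate(lg) for lg in cfg.logger.values()]
    trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
    trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)


if __name__ == "__main__":
    main()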
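# Illustrative sketch of the masking scheme this module uses: per-position BCE is computed
# everywhere, zeroed outside the mask, and normalised by the number of masked positions.
# The toy tensors below are made up for the example.
import torch
from torch.nn import functional as F

logits = torch.tensor([2.0, -1.0, 0.5])   # raw model outputs for three positions
targets = torch.tensor([0.0, 0.0, 0.8])   # 0.0 outside peaks, a peak score inside peaks
nonpeak_mask = (targets == 0).float()     # 1.0 on the two non-peak positions

per_pos = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
bce_nonpeak = (per_pos * nonpeak_mask).sum() / nonpeak_mask.sum().clamp_min(1.0)
# The peak position is scored by the MSE term instead, and the two terms are combined as
# alpha * bce_nonpeak + gamma * mse_peak in calculate_loss.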
+ We will only compute BCE on positions whre nonpeak_mask == 1.0; the mask represents non-peak positions + """ + loss = F.binary_cross_entropy_with_logits( + logits, targets, reduction="none", pos_weight=pos_weight + ) + denom = nonpeak_mask.sum().clamp_min(1.0) + return (loss * nonpeak_mask).sum() / denom + + +def mse_peaks_only(logits, targets, peak_mask, eps=1e-8): + """ + Calculate MSE on peaks only. + """ + probs = torch.sigmoid(logits) + mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction="sum") / ( + peak_mask.sum() + eps + ) + return mse_peaks + + +def calculate_loss(logits, targets, eps=1e-8, alpha=1.0, gamma=1.0): + """ + Combine masked-BCE + global-MSE to get a loss vlaue + """ + # Calculate peak and non-peak masks. + # Anything outside a peak will have a label equal to 0. + nonpeak_mask = (targets == 0).float() + peak_mask = (targets > 0).float() + + bce_nonpeak = bce_loss_masked(logits, targets, nonpeak_mask) + mse_peak = mse_peaks_only(logits, targets, peak_mask, eps=eps) + + loss = alpha * bce_nonpeak + gamma * mse_peak + + return loss + + +def accuracy_percentage(logits, targets, peak_thresh=0.5): + """ + Compute accuracy in predicting high-confidence peaks (probability > 0.5) + """ + probs = torch.sigmoid(logits) + preds_bin = (probs >= 0.5).float() + labels = (targets >= peak_thresh).float() + correct = (preds_bin == labels).float().sum() + total = torch.numel(labels) + return (correct / max(1, total)).item() * 100.0 diff --git a/dpacman/classifier/model.py b/dpacman/classifier/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0549a464e2d3add5e481a562b3aa54e5a7553064 --- /dev/null +++ b/dpacman/classifier/model.py @@ -0,0 +1,258 @@ +""" +Lightning Module for the binding model. +""" + +import torch +from torch import nn +from lightning import LightningModule +from dpacman.utils.models import set_seed +from .loss import calculate_loss + +set_seed() + + +class LocalCNN(nn.Module): + def __init__(self, dim: int = 256, kernel_size: int = 3): + super().__init__() + padding = kernel_size // 2 + self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding) + self.act = nn.GELU() + self.ln = nn.LayerNorm(dim) + + def forward(self, x: torch.Tensor): + # x: (batch, L, dim) + out = self.conv(x.transpose(1, 2)) # → (batch, dim, L) + out = self.act(out) + out = out.transpose(1, 2) # → (batch, L, dim) + return self.ln(out + x) # residual + + +class CrossModalBlock(nn.Module): + def __init__(self, dim: int = 256, heads: int = 8): + super().__init__() + # self-attention for both sides + self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True) + self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True) + self.ln_b1 = nn.LayerNorm(dim) + self.ln_g1 = nn.LayerNorm(dim) + + self.ffn_b = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) + ) + self.ffn_g = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) + ) + self.ln_b2 = nn.LayerNorm(dim) + self.ln_g2 = nn.LayerNorm(dim) + + # cross attention (binder queries, glm keys/values) + # so the NDA path is updated by the transcriptoin factors + self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True) + self.ln_c1 = nn.LayerNorm(dim) + self.ffn_c = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) + ) + self.ln_c2 = nn.LayerNorm(dim) + + def forward(self, binder: torch.Tensor, glm: torch.Tensor): + """ + binder: (batch, Lb, dim) + glm: (batch, Lg, dim) -- has passed through its 
local CNN beforehand + returns: updated binder representation (batch, Lb, dim) + """ + # binder: self-attn + ffn + b = binder + b_sa, _ = self.sa_binder(b, b, b) + b = self.ln_b1(b + b_sa) + b_ff = self.ffn_b(b) + b = self.ln_b2(b + b_ff) + + # glm: self-attn + ffn + g = glm + g_sa, _ = self.sa_glm(g, g, g) + g = self.ln_g1(g + g_sa) + g_ff = self.ffn_g(g) + g = self.ln_g2(g + g_ff) + + # cross-attention: glm queries binder and glm embeddings are updated + g_to_b_ca, _ = self.cross_attn(g, b, b) + g = self.ln_c1(g + g_to_b_ca) + g_ff = self.ffn_c(g) + g = self.ln_c2(g + g_ff) + return g # (batch, Lb, dim) + + +class DimCompressor(nn.Module): + """ + Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256). + If in_dim == out_dim, behaves as identity. + """ + + def __init__(self, in_dim: int, out_dim: int = 256): + super().__init__() + if in_dim == out_dim: + self.net = nn.Identity() + else: + hidden = max(out_dim * 2, (in_dim + out_dim) // 2) + self.net = nn.Sequential( + nn.LayerNorm(in_dim), + nn.Linear(in_dim, hidden), + nn.GELU(), + nn.Linear(hidden, out_dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (B, L, in_dim) + return self.net(x) + + +class BindPredictor(LightningModule): + def __init__( + self, + # input_dim: int = 256, # OLD: single input dim + binder_input_dim: int = 1280, # NEW: TF (binder) original dim (e.g., 1280) + glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256) + compressed_dim: int = 256, # NEW: learnable compressed dim + hidden_dim: int = 256, + heads: int = 8, + num_layers: int = 4, + lr: float = 1e-4, + alpha: float = 20, + gamma: float = 20, + use_local_cnn_on_glm: bool = True, + weight_decay: float = 0.01, + ): + # Init + super(BindPredictor, self).__init__() + self.save_hyperparameters() + + # Learnable compressor for binder -> 256, then project to hidden + self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim) + self.proj_binder = nn.Linear(compressed_dim, hidden_dim) + + # GLM side stays 256 -> hidden + self.proj_glm = nn.Linear(glm_input_dim, hidden_dim) + + self.use_local_cnn = use_local_cnn_on_glm + self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity() + + self.layers = nn.ModuleList( + [CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)] + ) + + self.ln_out = nn.LayerNorm(hidden_dim) + # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) # OLD: returned probabilities + self.head = nn.Linear(hidden_dim, 1) # NEW: return logits (safe for AMP) + + def forward(self, binder_emb, glm_emb): + """ + binder_emb: (B, Lb, binder_input_dim) + glm_emb: (B, Lg, glm_input_dim) + Returns per-nucleotide logits for the GLM sequence: (B, Lg) + """ + # Binder: learnable compression → 256 → hidden + b = self.binder_compress(binder_emb) # (B, Lb, 256) + b = self.proj_binder(b) # (B, Lb, hidden_dim) + + # GLM: project → hidden, add local CNN context + g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim) + if self.use_local_cnn: + g = self.local_cnn(g) + + # Cross-modal blocks: update binder states using GLM + for layer in self.layers: + g = layer(b, g) # (B, Lb, hidden_dim) + + # Predict per-nucleotide logits on the GLM tokens: + # return self.head(g).squeeze(-1) # OLD: probabilities (with Sigmoid in head) + return self.head(g).squeeze( + -1 + ) # NEW: logits (apply sigmoid only in loss/metrics) + + # ----- Lightning hooks ----- + def training_step(self, batch, batch_idx): + """ + Training step taken by PyTorch-Lightning trainer. 
Uses batch returned by data collator. + Colator returns a dictionary with: + "binder_emb" # [B, Lb_max, Db] + "binder_mask" # [B, Lb_max] + "glm_emb" # [B, Lg_max, Dg] + "glm_mask" # [B, Lg_max] + "labels" # [B, Lg_max] + "ID" + "tr_sequence" + "dna_sequence" + } + """ + logits = self.forward(batch["binder_emb"], batch["glm_emb"]) + loss = calculate_loss( + logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma + ) + self.log( + "train/loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + batch_size=logits.size(0), + ) + return loss + + def validation_step(self, batch, batch_idx): + logits = self.forward(batch["binder_emb"], batch["glm_emb"]) + loss = calculate_loss( + logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma + ) + self.log( + "val/loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=logits.size(0), + ) + return loss + + def test_step(self, batch, batch_idx): + logits = self.forward(batch["binder_emb"], batch["glm_emb"]) + loss = calculate_loss( + logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma + ) + self.log( + "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0) + ) + return loss + + def on_train_epoch_end(self): + if False: + if self.train_auc.compute() is not None: + self.log("train/auroc", self.train_auc.compute(), prog_bar=True) + self.train_auc.reset() + + def on_validation_epoch_end(self): + if False: + if self.val_auc.compute() is not None: + self.log("val/auroc", self.val_auc.compute(), prog_bar=True) + self.val_auc.reset() + + def on_test_epoch_end(self): + if False: + if self.test_auc.compute() is not None: + self.log("test/auroc", self.test_auc.compute(), prog_bar=True) + self.test_auc.reset() + + def configure_optimizers(self): + # AdamW + cosine as a sensible default + opt = torch.optim.AdamW( + self.parameters(), + lr=self.hparams.lr, + weight_decay=self.hparams.weight_decay, + ) + # Scheduler optional—comment out if you prefer fixed LR + sch = torch.optim.lr_scheduler.CosineAnnealingLR( + opt, T_max=max(self.trainer.max_epochs, 1) + ) + return { + "optimizer": opt, + "lr_scheduler": {"scheduler": sch, "interval": "epoch"}, + } diff --git a/dpacman/classifier/model/clustering_data.py b/dpacman/classifier/model/clustering_data.py deleted file mode 100644 index a8276866cc86ed600841a5c679efc797418a89d3..0000000000000000000000000000000000000000 --- a/dpacman/classifier/model/clustering_data.py +++ /dev/null @@ -1,383 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import numpy as np -import pandas as pd -from pathlib import Path -import random -import sys -import subprocess -from collections import defaultdict - -# ───────────────────────────────────────────────────────────────────────── -# Original helpers (kept; some lightly edited/commented where needed) -# ───────────────────────────────────────────────────────────────────────── - -def read_ids_file(p): - p = Path(p) - if not p.exists(): - raise FileNotFoundError(f"IDs file not found: {p}") - return [line.strip() for line in p.open() if line.strip()] - -def split_embeddings(emb_path, ids_path, out_dir, prefix): - out_dir = Path(out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - - if not Path(emb_path).exists(): - raise FileNotFoundError(f"Embedding file not found: {emb_path}") - if not Path(ids_path).exists(): - raise FileNotFoundError(f"IDs file not found: {ids_path}") - - if emb_path.endswith(".npz"): - data = np.load(emb_path, allow_pickle=True) - if "embeddings" in data: 
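# Illustrative sketch of the shape contract the new BindPredictor above is written against.
# The feature dimensions are assumptions taken from configs/model/classifier.yaml and the
# constructor defaults; batch size and sequence lengths are toy values.
import torch
from dpacman.classifier.model import BindPredictor

model = BindPredictor(binder_input_dim=1280, glm_input_dim=1029, compressed_dim=1029)
binder_emb = torch.randn(2, 300, 1280)   # (B, Lb, binder_input_dim) protein/TF token embeddings
glm_emb = torch.randn(2, 500, 1029)      # (B, Lg, glm_input_dim) per-nucleotide DNA embeddings
logits = model(binder_emb, glm_emb)      # (B, Lg) = (2, 500) per-nucleotide logits
probs = torch.sigmoid(logits)            # sigmoid is applied only in the loss/metrics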
- emb = data["embeddings"] - else: - raise ValueError(f"{emb_path} missing 'embeddings' key") - else: - emb = np.load(emb_path) - - ids = read_ids_file(ids_path) - if len(ids) != emb.shape[0]: - print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr) - - mapping = {} - for i, ident in enumerate(ids): - if i >= emb.shape[0]: - print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr) - continue - arr = emb[i] - out_file = out_dir / f"{prefix}_{ident}.npy" - np.save(out_file, arr) - mapping[ident] = str(out_file) - return mapping - -def extract_symbol_from_tf_id(full_id: str) -> str: - """ - Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN', - return the gene symbol uppercase (e.g., 'ZBTB5'). - """ - if "|" in full_id: - try: - # format sp|Accession|SYMBOL_HUMAN - genepart = full_id.split("|")[2] - except IndexError: - genepart = full_id - else: - genepart = full_id - symbol = genepart.split("_")[0] - return symbol.upper() - -def build_tf_symbol_map(tf_map): - """ - Build mapping gene_symbol -> list of embedding paths. - """ - symbol_map = {} - for full_id, path in tf_map.items(): - symbol = extract_symbol_from_tf_id(full_id) - symbol_map.setdefault(symbol, []).append(path) - return symbol_map - -def tf_key_from_path(path: str) -> str: - """ - Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'. - """ - stem = Path(path).stem # e.g., tf_sp|O15062|ZBTB5_HUMAN - # remove leading prefix if present (tf_) - if "_" in stem: - _, rest = stem.split("_", 1) - else: - rest = stem - return extract_symbol_from_tf_id(rest) - -def dna_key_from_path(path: str) -> str: - """ - Given .../dna_peak42.npy -> 'peak42' - """ - stem = Path(path).stem - if "_" in stem: - _, rest = stem.split("_", 1) - else: - rest = stem - return rest - -# ───────────────────────────────────────────────────────────────────────── -# New helpers for MMseqs clustering & cluster-level splitting -# ───────────────────────────────────────────────────────────────────────── - -def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None: - """ - Write unique DNA sequences to FASTA using dna_id as header. - Requires df with columns: dna_id, dna_sequence - """ - uniq = df[["dna_id", "dna_sequence"]].drop_duplicates() - with open(out_fasta, "w") as f: - for _, row in uniq.iterrows(): - did = row["dna_id"] - seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "") - f.write(f">{did}\n{seq}\n") - -def run_mmseqs_easy_cluster( - mmseqs_bin: str, - fasta: Path, - out_prefix: Path, - tmp_dir: Path, - min_seq_id: float, - coverage: float, - cov_mode: int, -) -> Path: - """ - Runs mmseqs easy-cluster on nucleotide sequences. - Returns the path to a clusters TSV file (creating it if the default one isn't present). 
- """ - tmp_dir.mkdir(parents=True, exist_ok=True) - out_prefix.parent.mkdir(parents=True, exist_ok=True) - - cmd = [ - mmseqs_bin, "easy-cluster", - str(fasta), str(out_prefix), str(tmp_dir), - "--min-seq-id", str(min_seq_id), - "-c", str(coverage), - "--cov-mode", str(cov_mode), - # You can add performance flags here if needed, e.g.: - # "--threads", "8" - ] - print("[i] Running:", " ".join(cmd), flush=True) - subprocess.run(cmd, check=True) - - # MMseqs easy-cluster typically writes _cluster.tsv - default_tsv = Path(str(out_prefix) + "_cluster.tsv") - if default_tsv.exists(): - print(f"[i] Found cluster TSV: {default_tsv}") - return default_tsv - - # Fallback: try createtsv if default is missing - # This requires the internal DBs. easy-cluster creates DBs alongside out_prefix. - # We'll try to locate them and emit a TSV. - in_db = Path(str(out_prefix) + "_query") - cl_db = Path(str(out_prefix) + "_cluster") - out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv") - if in_db.exists() and cl_db.exists(): - cmd2 = [mmseqs_bin, "createtsv", str(in_db), str(in_db), str(cl_db), str(out_tsv)] - print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True) - subprocess.run(cmd2, check=True) - if out_tsv.exists(): - return out_tsv - - raise FileNotFoundError("Could not locate clusters TSV from mmseqs. " - "Expected {default_tsv} or createtsv fallback.") - -def parse_mmseqs_clusters(tsv_path: Path) -> dict: - """ - Parse MMseqs cluster TSV (rep \t member). Returns dna_id -> cluster_rep_id - """ - mapping = {} - with open(tsv_path) as f: - for line in f: - parts = line.rstrip("\n").split("\t") - if len(parts) < 2: - continue - rep, member = parts[0], parts[1] - mapping[member] = rep - # Some TSVs include rep->rep; if not, ensure rep is mapped to itself: - if rep not in mapping: - mapping[rep] = rep - return mapping - -def assign_clusters_to_splits(cluster_rep_to_members: dict, - val_frac: float, - test_frac: float, - seed: int = 42): - """ - cluster_rep_to_members: dict[rep] = [members...] - Returns: dict with keys 'train','val','test' mapping to sets of dna_id. - Ensures all members of a cluster go to the same split. - """ - rng = random.Random(seed) - reps = list(cluster_rep_to_members.keys()) - rng.shuffle(reps) - - # Greedy-ish fill by total member counts to match desired fractions. - total = sum(len(cluster_rep_to_members[r]) for r in reps) - target_val = int(round(total * val_frac)) - target_test = int(round(total * test_frac)) - cur_val = cur_test = 0 - - val_ids, test_ids, train_ids = set(), set(), set() - for rep in reps: - members = cluster_rep_to_members[rep] - c = len(members) - # Fill val first, then test, then train - if cur_val + c <= target_val: - val_ids.update(members); cur_val += c - elif cur_test + c <= target_test: - test_ids.update(members); cur_test += c - else: - train_ids.update(members) - - return {"train": train_ids, "val": val_ids, "test": test_ids} - -# ───────────────────────────────────────────────────────────────────────── -# Main -# ───────────────────────────────────────────────────────────────────────── - -def main(): - parser = argparse.ArgumentParser( - description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage." 
- ) - parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence") - parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)") - parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)") - parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)") - parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)") - parser.add_argument("--out_dir", required=True, help="Output directory") - parser.add_argument("--seed", type=int, default=42) - - # NEW: MMseqs options & split fractions - parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary") - parser.add_argument("--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id") - parser.add_argument("--cov", type=float, default=0.8, help="MMseqs -c coverage fraction") - parser.add_argument("--cov_mode", type=int, default=1, help="MMseqs --cov-mode (1 = coverage of target)") - parser.add_argument("--val_frac", type=float, default=0.10) - parser.add_argument("--test_frac", type=float, default=0.10) - parser.add_argument("--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)") - args = parser.parse_args() - - random.seed(args.seed) - out_dir = Path(args.out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - - # Load final.csv - df = pd.read_csv(args.final_csv, dtype=str) - if "TF_id" not in df.columns or "dna_sequence" not in df.columns: - raise RuntimeError("final.csv must have columns TF_id and dna_sequence") - - # Assign dna_id (unique per dna_sequence) - unique_seqs = df["dna_sequence"].drop_duplicates().tolist() - seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)} - df["dna_id"] = df["dna_sequence"].map(seq_to_id) - enriched_csv = out_dir / "final_with_dna_id.csv" - df.to_csv(enriched_csv, index=False) - print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}") - - # Split embeddings into per-item files (unchanged) - print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}") - dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna") - print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})") - print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}") - tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf") - print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})") - - # Build gene-symbol normalized map - tf_symbol_map = build_tf_symbol_map(tf_map) - print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}") - - # Diagnostic overlaps - norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique()) - available_tf_symbols = set(tf_symbol_map.keys()) - intersect_tf = norm_tf_in_final & available_tf_symbols - print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}") - print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}") - print(f"[i] Intersection count: {len(intersect_tf)}") - if len(intersect_tf) == 0: - print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr) - print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr) - print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr) - sys.exit(1) - - 
dna_ids_final = set(df["dna_id"].unique()) - available_dna_ids = set(dna_map.keys()) - intersect_dna = dna_ids_final & available_dna_ids - print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}") - if len(intersect_dna) == 0: - print("[ERROR] No overlap on DNA ids.", file=sys.stderr) - sys.exit(1) - - # ── NEW: MMseqs clustering on DNA sequences ─────────────────────────── - fasta_path = out_dir / "dna_unique.fasta" - write_dna_fasta(df, fasta_path) - print(f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}") - - tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp") - cluster_prefix = out_dir / "mmseqs_dna_clusters" - clusters_tsv = run_mmseqs_easy_cluster( - mmseqs_bin=args.mmseqs_bin, - fasta=fasta_path, - out_prefix=cluster_prefix, - tmp_dir=tmp_dir, - min_seq_id=args.min_seq_id, - coverage=args.cov, - cov_mode=args.cov_mode, - ) - - # Parse clusters - member_to_rep = parse_mmseqs_clusters(clusters_tsv) # dna_id -> rep_id - # Build rep -> members list - rep_to_members = defaultdict(list) - for member, rep in member_to_rep.items(): - rep_to_members[rep].append(member) - - print(f"[i] Parsed {len(rep_to_members)} clusters from {clusters_tsv}") - clusters_table = [] - for rep, members in rep_to_members.items(): - for m in members: - clusters_table.append((m, rep)) - clusters_df = pd.DataFrame(clusters_table, columns=["dna_id", "cluster_id"]) - clusters_df.to_csv(out_dir / "clusters.tsv", sep="\t", index=False) - print(f"[i] Wrote clusters mapping → {out_dir / 'clusters.tsv'}") - - # Attach cluster_id back to final df - df = df.merge(clusters_df, on="dna_id", how="left") - df.to_csv(out_dir / "final_with_dna_id_and_cluster.csv", index=False) - print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}") - - # Assign entire clusters to splits - splits = assign_clusters_to_splits(rep_to_members, - val_frac=args.val_frac, - test_frac=args.test_frac, - seed=args.seed) - for k in ["train", "val", "test"]: - print(f"[i] {k}: {len(splits[k])} dna_ids") - - # ── Build positive pairs only, per split (NO negatives) ─────────────── - positives_by_split = {"train": [], "val": [], "test": []} - # Build a quick dna_id -> embedding path map - dnaid_to_path = {did: path for did, path in dna_map.items()} - - pos_count = 0 - for _, row in df.iterrows(): - tf_raw = row["TF_id"] - tf_symbol = tf_raw.split("_seq")[0].upper() - dnaid = row["dna_id"] - if (tf_symbol not in tf_symbol_map) or (dnaid not in dnaid_to_path): - continue - tf_embedding_path = tf_symbol_map[tf_symbol][0] # first embedding per symbol - - # decide split by dna_id cluster assignment - if dnaid in splits["train"]: - positives_by_split["train"].append((tf_embedding_path, dnaid_to_path[dnaid], 1)) - elif dnaid in splits["val"]: - positives_by_split["val"].append((tf_embedding_path, dnaid_to_path[dnaid], 1)) - elif dnaid in splits["test"]: - positives_by_split["test"].append((tf_embedding_path, dnaid_to_path[dnaid], 1)) - pos_count += 1 - - print(f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})") - for k in ["train", "val", "test"]: - print(f"[i] positives[{k}] = {len(positives_by_split[k])}") - - # # OLD: negatives (kept commented) - # negatives = [] - # print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive not used)") - - # Emit split-specific pair lists - for split in ["train", "val", "test"]: - out_tsv = out_dir / f"pair_list_{split}.tsv" - with 
open(out_tsv, "w") as f: - for binder_path, glm_path, label in positives_by_split[split]: # + negatives if you add later - f.write(f"{binder_path}\t{glm_path}\t{label}\n") - print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}") - - print("✅ Done. Cluster-aware splits ready.") - -if __name__ == "__main__": - main() diff --git a/dpacman/classifier/model/compress_embeddings.py b/dpacman/classifier/model/compress_embeddings.py deleted file mode 100644 index 248f0143ff2430b3c244b856cd853895c51d3552..0000000000000000000000000000000000000000 --- a/dpacman/classifier/model/compress_embeddings.py +++ /dev/null @@ -1,54 +0,0 @@ -# compress_embeddings.py -# USAGE: python compress_embeddings.py --input_glob "/path/to/esm_embeddings/*.npy" --output_dir "/path/to/compressed_embeddings" --esm_dim 1280 --out_dim 256 -# -------------- -import os -import glob -import numpy as np -import torch -from torch import nn - -class EmbeddingCompressor(nn.Module): - def __init__(self, input_dim: int = 1280, output_dim: int = 256): - super().__init__() - self.fc = nn.Linear(input_dim, output_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - x: (batch, L, input_dim) or (L, input_dim) - returns: (batch, output_dim) or (output_dim,) - """ - if x.dim() == 2: - # single example: mean over tokens - x = x.mean(dim=0, keepdim=True) # → (1, input_dim) - else: - # batch: mean over tokens - x = x.mean(dim=1) # → (batch, input_dim) - return self.fc(x) # → (batch, output_dim) - -def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor): - arr = np.load(in_path) # shape (L, D) or (batch, L, D) - tensor = torch.from_numpy(arr).float() - with torch.no_grad(): - compressed = model(tensor) # → (batch, 256) - out = compressed.cpu().numpy() - np.save(out_path, out) - print(f"Saved {out_path}") - -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256­d") - parser.add_argument("--input_glob", type=str, required=True, - help="Glob for your .npy ESM embeddings (e.g. 
data/esm_*.npy)") - parser.add_argument("--output_dir", type=str, required=True) - parser.add_argument("--esm_dim", type=int, default=1280) - parser.add_argument("--out_dim", type=int, default=256) - args = parser.parse_args() - - os.makedirs(args.output_dir, exist_ok=True) - compressor = EmbeddingCompressor(args.esm_dim, args.out_dim) - compressor.eval() - - for fn in glob.glob(args.input_glob): - base = os.path.basename(fn).replace(".npy", "_256.npy") - out_path = os.path.join(args.output_dir, base) - compress_file(fn, out_path, compressor) diff --git a/dpacman/classifier/model/compute_embeddings.py b/dpacman/classifier/model/compute_embeddings.py deleted file mode 100644 index de6b8447d33326eac8f19ba2ee651e9ffd72cb63..0000000000000000000000000000000000000000 --- a/dpacman/classifier/model/compute_embeddings.py +++ /dev/null @@ -1,560 +0,0 @@ -""" -Plug-and-play embedding extraction for: - • Chromosome sequences (from raw UCSC JSON) - • TF sequences (transcription_factors.fasta) - -Usage example (DNA + protein in one go): - module load miniconda/24.7.1 - conda activate dpacman - python dpacman/data/compute_embeddings.py \ - --genome-json-dir ../data_files/raw/genomes/hg38 \ - --tf-fasta ../data_files/processed/tfclust/hg38_tf/transcription_factors.fasta \ - --chrom-model caduceus \ - --tf-model esm-dbp \ - --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \ - --device cuda -""" -import os -import re -import argparse -import json -import numpy as np -from pathlib import Path -import torch -from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline -import esm -from Bio import SeqIO -import time -import pandas as pd -from tqdm.auto import tqdm -import logging, math - -# ---- model wrappers ---- - -class CaduceusEmbedder: - def __init__(self, device, chunk_size=131_072, overlap=0): - """ - device: 'cpu' or 'cuda' - chunk_size: max bases (and thus tokens) to send in one forward pass - overlap: how many bases each window overlaps the previous; 0 = no overlap - """ - model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16" - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, trust_remote_code=True - ) - self.model = AutoModel.from_pretrained( - model_name, trust_remote_code=True - ).to(device).eval() - self.device = device - self.chunk_size = chunk_size - self.step = chunk_size - overlap - - def embed(self, seqs): - """ - seqs: List[str] of DNA sequences (each <= chunk_size for this test) - returns: np.ndarray of shape (N, L, D), raw per‐token embeddings - """ - # outputs = [] - # for seq in seqs: - # # --- new: raw per‐token embeddings in one shot --- - # toks = self.tokenizer( - # seq, - # return_tensors="pt", - # padding=False, - # truncation=True, - # max_length=self.chunk_size - # ).to(self.device) - # with torch.no_grad(): - # out = self.model(**toks).last_hidden_state # (1, L, D) - # outputs.append(out.cpu().numpy()[0]) # (L, D) - - # return np.stack(outputs, axis=0) # (N, L, D) - outputs = [] - for seq in tqdm(seqs, total=len(seqs), desc="DNA: Caduceus", dynamic_ncols=True): - toks = self.tokenizer( - seq, - return_tensors="pt", - padding=False, - truncation=True, - max_length=self.chunk_size - ).to(self.device) - with torch.no_grad(): - out = self.model(**toks).last_hidden_state # (1, L, D) - outputs.append(out.cpu().numpy()[0]) # (L, D) - return outputs # list of variable-length (L_i, D) arrays - - - def benchmark(self, lengths=None): - """ - Time embedding on single-sequence of various lengths. 
- By default tests [5K,10K,50K,100K,chunk_size]. - """ - tests = lengths or [5_000, 10_000, 50_000, 100_000, self.chunk_size] - print(f"→ Benchmarking Caduceus on device={self.device}") - for sz in tests: - seq = "A" * sz - # Warm-up - _ = self.embed([seq]) - if self.device != "cpu": - torch.cuda.synchronize() - t0 = time.perf_counter() - _ = self.embed([seq]) - if self.device != "cpu": - torch.cuda.synchronize() - t1 = time.perf_counter() - print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms") - -class SegmentNTEmbedder: - def __init__(self, device): - self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True) - self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval() - self.device = device - - def _adjust_length(self, input_ids): - bs, L = input_ids.shape - excl = L - 1 - remainder = (excl) % 4 - if remainder != 0: - pad_needed = 4 - remainder - pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device) - input_ids = torch.cat([input_ids, pad_tensor], dim=1) - return input_ids - - def embed(self, seqs, batch_size=16): - """ - seqs: List[str] - Returns: np.ndarray of shape (N, D) - """ - all_embeddings = [] - for i in range(0, len(seqs), batch_size): - batch_seqs = seqs[i : i + batch_size] - encoded = self.tokenizer.batch_encode_plus( - batch_seqs, - return_tensors="pt", - padding=True, - truncation=True, - ) - input_ids = encoded["input_ids"].to(self.device) # (B, L) - attention_mask = input_ids != self.tokenizer.pad_token_id - - input_ids = self._adjust_length(input_ids) - attention_mask = (input_ids != self.tokenizer.pad_token_id) - - with torch.no_grad(): - outs = self.model( - input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - return_dict=True, - ) - if hasattr(outs, "hidden_states") and outs.hidden_states is not None: - last_hidden = outs.hidden_states[-1] # (B, L, D) - else: - last_hidden = outs.last_hidden_state # fallback - - # Exclude CLS token if present (assume first token) and pool - pooled = last_hidden[:, 1:, :].mean(dim=1) # (B, D) - all_embeddings.append(pooled.cpu().numpy()) - - # release fragmentation - torch.cuda.empty_cache() - - return np.vstack(all_embeddings) # (N, D) - - -class DNABertEmbedder: - def __init__(self, device): - self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True) - self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device) - self.device = device - - def embed(self, seqs): - embs = [] - for s in seqs: - tokens = self.tokenizer(s, return_tensors="pt", padding=True)["input_ids"].to(self.device) - with torch.no_grad(): - out = self.model(tokens).last_hidden_state.mean(1) - embs.append(out.cpu().numpy()) - return np.vstack(embs) - -class NucleotideTransformerEmbedder: - def __init__(self, device): - # HF “feature-extraction” returns a list of (L, D) arrays for each input - # device: “cpu” or “cuda” - self.pipe = pipeline( - "feature-extraction", - model="InstaDeepAI/nucleotide-transformer-500m-1000g", - device= -1 if device=="cpu" else 0 # HF uses -1 for CPU, 0 for GPU #:contentReference[oaicite:0]{index=0} - ) - - def embed(self, seqs): - """ - seqs: List[str] of raw DNA sequences - returns: (N, D) array, one D-dim vector per sequence - """ - all_embeddings = self.pipe(seqs, truncation=True, padding=True) - # all_embeddings is a List of shape (L, D) arrays - pooled = [ np.mean(x, axis=0) for x in 
all_embeddings ] - return np.vstack(pooled) - -# class ESMEmbedder: -# def __init__(self, device): -# self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S() -# self.batch_converter = self.alphabet.get_batch_converter() -# self.model.to(device).eval() -# self.device = device - -# def embed(self, seqs): -# batch = [(str(i), seq) for i, seq in enumerate(seqs)] -# _, _, toks = self.batch_converter(batch) -# toks = toks.to(self.device) -# with torch.no_grad(): -# results = self.model(toks, repr_layers=[33], return_contacts=False) -# reps = results["representations"][33] -# return reps[:, 1:-1].mean(1).cpu().numpy() - - -class ESMEmbedder: - def __init__(self, device, model_name="esm2_t33_650M_UR50D"): - # Try to load the specified ESM-2 model; fallback to esm1b if missing - self.device = device - try: - self.model, self.alphabet = getattr(esm.pretrained, model_name)() - self.is_esm2 = model_name.lower().startswith("esm2") - except AttributeError: - # fallback to ESM-1b - self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S() - self.is_esm2 = False - self.batch_converter = self.alphabet.get_batch_converter() - self.model.to(device).eval() - # determine max length: esm2 models vary; use default 1024 for esm1b - self.max_len = 4096 if self.is_esm2 else 1024 # adjust if your esm2 variant has explicit limit - # for chunking: reserve 2 tokens if model uses BOS/EOS - self.chunk_size = self.max_len - 2 - self.overlap = self.chunk_size // 4 # 25% overlap to smooth boundaries - - def _chunk_sequence(self, seq): - """ - Return list of possibly overlapping chunks of seq, each <= chunk_size. - """ - if len(seq) <= self.chunk_size: - return [seq] - step = self.chunk_size - self.overlap - chunks = [] - for i in range(0, len(seq), step): - chunk = seq[i : i + self.chunk_size] - if not chunk: - break - chunks.append(chunk) - return chunks - - def embed(self, seqs): - """ - seqs: List[str] of protein sequences. - Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings. 
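# Illustrative sketch of the overlapping-window arithmetic used by the _chunk_sequence helpers
# above, using the ESM-1b numbers from this file (max_len 1024 minus BOS/EOS, 25% overlap).
chunk_size = 1024 - 2        # 1022 residues per window
overlap = chunk_size // 4    # 255 residues shared between neighbouring windows
step = chunk_size - overlap  # 767: stride between window starts

seq = "M" * 2500
chunks = [seq[i:i + chunk_size] for i in range(0, len(seq), step) if seq[i:i + chunk_size]]
# Window starts 0, 767, 1534, 2301 -> chunk lengths 1022, 1022, 966, 199; each window is embedded
# separately and the per-window vectors are averaged into a single (D,) vector per sequence.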
- """ - all_embeddings = [] - for i, seq in enumerate(seqs): - chunks = self._chunk_sequence(seq) - chunk_vecs = [] - # process chunks in batch if small number, else sequentially - for chunk in chunks: - batch = [(str(i), chunk)] - _, _, toks = self.batch_converter(batch) - toks = toks.to(self.device) - with torch.no_grad(): - results = self.model(toks, repr_layers=[33], return_contacts=False) - reps = results["representations"][33] # (1, L, D) - # remove BOS/EOS if present: take 1:-1 if length permits - if reps.size(1) > 2: - rep = reps[:, 1:-1].mean(1) # (1, D) - else: - rep = reps.mean(1) # fallback - chunk_vecs.append(rep.squeeze(0)) # (D,) - if len(chunk_vecs) == 1: - seq_vec = chunk_vecs[0] - else: - # average chunk vectors - stacked = torch.stack(chunk_vecs, dim=0) # (num_chunks, D) - seq_vec = stacked.mean(0) - all_embeddings.append(seq_vec.cpu().numpy()) - return np.vstack(all_embeddings) # (N, D) - - -# class ESMDBPEmbedder: -# def __init__(self, device): -# base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S() -# model_path = ( -# Path(__file__).resolve().parent.parent -# / "pretrained" / "ESM-DBP" / "ESM-DBP.model" -# ) -# checkpoint = torch.load(model_path, map_location="cpu") -# clean_sd = {} -# for k, v in checkpoint.items(): -# clean_sd[k.replace("module.", "")] = v -# result = base_model.load_state_dict(clean_sd, strict=False) -# if result.missing_keys: -# print(f"[ESMDBP] missing keys: {result.missing_keys}") -# if result.unexpected_keys: -# print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}") - -# self.model = base_model.to(device).eval() -# self.alphabet = alphabet -# self.batch_converter = alphabet.get_batch_converter() -# self.device = device - -# def embed(self, seqs): -# batch = [(str(i), seq) for i, seq in enumerate(seqs)] -# _, _, toks = self.batch_converter(batch) -# toks = toks.to(self.device) -# with torch.no_grad(): -# out = self.model(toks, repr_layers=[33], return_contacts=False) -# reps = out["representations"][33] -# # skip start/end tokens -# return reps[:, 1:-1].mean(1).cpu().numpy() - -class ESMDBPEmbedder: - def __init__(self, device): - base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S() - model_path = ( - Path(__file__).resolve().parent.parent - / "pretrained" / "ESM-DBP" / "ESM-DBP.model" - ) - checkpoint = torch.load(model_path, map_location="cpu") - clean_sd = {} - for k, v in checkpoint.items(): - clean_sd[k.replace("module.", "")] = v - result = base_model.load_state_dict(clean_sd, strict=False) - if result.missing_keys: - print(f"[ESMDBP] missing keys: {result.missing_keys}") - if result.unexpected_keys: - print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}") - - self.model = base_model.to(device).eval() - self.alphabet = alphabet - self.batch_converter = alphabet.get_batch_converter() - self.device = device - self.max_len = 1024 # same limit as esm1b - self.chunk_size = self.max_len - 2 - self.overlap = self.chunk_size // 4 - - def _chunk_sequence(self, seq): - if len(seq) <= self.chunk_size: - return [seq] - step = self.chunk_size - self.overlap - chunks = [] - for i in range(0, len(seq), step): - chunk = seq[i : i + self.chunk_size] - if not chunk: - break - chunks.append(chunk) - return chunks - - def embed(self, seqs): - all_embeddings = [] - for i, seq in enumerate(seqs): - chunks = self._chunk_sequence(seq) - chunk_vecs = [] - for chunk in chunks: - batch = [(str(i), chunk)] - _, _, toks = self.batch_converter(batch) - toks = toks.to(self.device) - with torch.no_grad(): - out = self.model(toks, 
repr_layers=[33], return_contacts=False) - reps = out["representations"][33] - if reps.size(1) > 2: - rep = reps[:, 1:-1].mean(1) - else: - rep = reps.mean(1) - chunk_vecs.append(rep.squeeze(0)) - if len(chunk_vecs) == 1: - seq_vec = chunk_vecs[0] - else: - stacked = torch.stack(chunk_vecs, dim=0) - seq_vec = stacked.mean(0) - all_embeddings.append(seq_vec.cpu().numpy()) - return np.vstack(all_embeddings) - -class GPNEmbedder: - def __init__(self, device): - model_name = "songlab/gpn-msa-sapiens" - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModelForMaskedLM.from_pretrained(model_name) - self.model.to(device) - self.model.eval() - self.device = device - - def embed(self, seqs): - inputs = self.tokenizer( - seqs, - return_tensors="pt", - padding=True, - truncation=True - ).to(self.device) - - with torch.no_grad(): - last_hidden = self.model(**inputs).last_hidden_state - return last_hidden.mean(dim=1).cpu().numpy() - -class ProGenEmbedder: - def __init__(self, device): - model_name = "jinyuan22/ProGen2-base" - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModel.from_pretrained(model_name).to(device).eval() - self.device = device - - def embed(self, seqs): - inputs = self.tokenizer( - seqs, - return_tensors="pt", - padding=True, - truncation=True - ).to(self.device) - with torch.no_grad(): - last_hidden = self.model(**inputs).last_hidden_state - return last_hidden.mean(dim=1).cpu().numpy() - -# ---- main pipeline ---- - -def get_embedder(name, device, for_dna=True): - name = name.lower() - if for_dna: - if name=="caduceus": return CaduceusEmbedder(device) - if name=="dnabert": return DNABertEmbedder(device) - if name=="nucleotide": return NucleotideTransformerEmbedder(device) - if name=="gpn": return GPNEmbedder(device) - if name=="segmentnt": return SegmentNTEmbedder(device) - else: - if name in ("esm",): return ESMEmbedder(device) - if name in ("esm-dbp","esm_dbp"): return ESMDBPEmbedder(device) - if name=="progen": return ProGenEmbedder(device) - raise ValueError(f"Unknown model {name} (for_dna={for_dna})") - - -def pad_token_embeddings(list_of_arrays, pad_value=0.0): - """ - list_of_arrays: list of (L_i, D) numpy arrays - Returns: - padded: (N, L_max, D) array - mask: (N, L_max) boolean array where True = real token, False = padding - """ - N = len(list_of_arrays) - D = list_of_arrays[0].shape[1] - L_max = max(arr.shape[0] for arr in list_of_arrays) - padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype) - mask = np.zeros((N, L_max), dtype=bool) - for i, arr in enumerate(list_of_arrays): - L = arr.shape[0] - padded[i, :L] = arr - mask[i, :L] = True - return padded, mask - -def embed_and_save(seqs, ids, embedder, out_path): - embs = embedder.embed(seqs) - - # Decide whether we got variable-length per-token outputs (list of (L, D)) - is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and embs[0].ndim == 2 - - if is_variable_token: - # pad to (N, L_max, D) + mask - padded, mask = pad_token_embeddings(embs) - # Save both embeddings and mask together in an .npz for convenience - np.savez_compressed(out_path.with_suffix(".caduceus.npz"), - embeddings=padded, - mask=mask, - ids=np.array(ids, dtype=object)) - else: - # fixed shape output, e.g., pooled (N, D) - array = np.vstack(embs) if isinstance(embs, list) else embs - np.save(out_path, array) - with open(out_path.with_suffix(".ids"), "w") as f: - f.write("\n".join(ids)) - - -if __name__=="__main__": - - p = 
argparse.ArgumentParser() - #p.add_argument("--peak_fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs") - p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes") - p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step") #if glm embeddings successful but not plm embeddings - p.add_argument("--tf-fasta", required=True, help="input TF FASTA file") - p.add_argument("--chrom-model", default="caduceus") - p.add_argument("--tf-model", default="esm-dbp") - p.add_argument("--out-dir", default="dpacman/model/embeddings") - p.add_argument("--device", default="cpu") - args = p.parse_args() - - os.makedirs(args.out_dir, exist_ok=True) - device = args.device - print(device) - - if not args.skip_dna: - if args.genome_json_dir == None: - dna_df = pd.read_parquet('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.parquet', engine='pyarrow') - #df.to_csv('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.csv', index=False) - peak_seqs = dna_df["dna_sequence"] - peak_ids = dna_df["ID"] - print(f"Embedding {len(peak_seqs)} binding peak sequences from processed remap data", flush=True) - dna_embedder = get_embedder(args.chrom_model, device, for_dna=True) - out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy" - embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks) - - # peak_fasta = Path(args.peak_fasta) - # if peak_fasta.exists(): - # # Load peak sequences from FASTA - # from Bio import SeqIO - - # peak_seqs = [] - # peak_ids = [] - # for rec in SeqIO.parse(peak_fasta, "fasta"): - # peak_ids.append(rec.id) - # peak_seqs.append(str(rec.seq)) - # print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True) - # dna_embedder = get_embedder(args.chrom_model, device, for_dna=True) - # out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy" - # embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks) - elif args.genome_json_dir: - # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M) - genome_dir = Path(args.genome_json_dir) - chrom_seqs, chrom_ids = [], [] - primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$") - for j in sorted(genome_dir.iterdir()): - if not primary_pattern.match(j.name): - continue - data = json.loads(j.read_text()) - seq = data.get("dna") or data.get("sequence") - chrom = data.get("chrom") or j.stem.split("_")[-1] - chrom_seqs.append(seq) - chrom_ids.append(chrom) - cutoff = CaduceusEmbedder(device).chunk_size - long_chroms = [ - (chrom, len(seq)) - for chrom, seq in zip(chrom_ids, chrom_seqs) - if len(seq) > cutoff - ] - if long_chroms: - print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff)) - for chrom, L in long_chroms: - print(f" {chrom}: {L} bases") - else: - print("All chromosomes ≤ Caduceus limit ({}).".format(cutoff)) - - chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True) - out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy" - embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom) - else: - raise ValueError("No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs.") - - - #Load TF sequences - tf_seqs, tf_ids = [], [] - for 
record in SeqIO.parse(args.tf_fasta, "fasta"): - tf_ids.append(record.id) - tf_seqs.append(str(record.seq)) - - # embed and save - tf_embedder = get_embedder(args.tf_model, device, for_dna=False) - out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy" - embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf) - - print("Done.") \ No newline at end of file diff --git a/dpacman/classifier/model/extract_tf_symbols.py b/dpacman/classifier/model/extract_tf_symbols.py deleted file mode 100644 index 3c833d8454b71fdd7b05088f73a44605a31b95f0..0000000000000000000000000000000000000000 --- a/dpacman/classifier/model/extract_tf_symbols.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python3 -import pandas as pd -from pathlib import Path - -FINAL_CSV = Path("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv") -OUT_SYMBOLS = Path("tf_symbols.txt") - -def normalize_tf(tf_id: str) -> str: - return tf_id.split("_seq")[0].upper() - -def main(): - df = pd.read_csv(FINAL_CSV, dtype=str) - if "TF_id" not in df.columns: - raise RuntimeError("final.csv missing TF_id column") - tf_raw = df["TF_id"].dropna().unique().tolist() - normalized = sorted({normalize_tf(t) for t in tf_raw}) - print(f"Unique raw TF_id count: {len(tf_raw)}") - print(f"Unique normalized TF symbols: {len(normalized)}") - with open(OUT_SYMBOLS, "w") as f: - for s in normalized: - f.write(s + "\n") - print(f"Wrote normalized TF symbols to {OUT_SYMBOLS}") - # Optional: show sample - print("Sample symbols:", normalized[:50]) - -if __name__ == "__main__": - main() diff --git a/dpacman/classifier/model/loss.py b/dpacman/classifier/model/loss.py deleted file mode 100644 index 845dee415686cfde286be33b1d83f38425e4605a..0000000000000000000000000000000000000000 --- a/dpacman/classifier/model/loss.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Define loss functions needed for training the model -""" -import torch -from torch.nn import functional as F - -def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8): - probs = torch.sigmoid(logits) - labels = (targets >= peak_thresh).float() - non_peak_mask = (labels == 0).float() - peak_mask = (labels == 1).float() - - bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction='none') - bce_non = (bce_all * non_peak_mask) - bce_non = bce_non.sum() / (non_peak_mask.sum() + eps) - - mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction='sum') \ - / (peak_mask.sum() + eps) - - t_dist = (targets + eps) - p_dist = (probs + eps) - t_dist = t_dist / t_dist.sum(dim=1, keepdim=True) - p_dist = p_dist / p_dist.sum(dim=1, keepdim=True) - kl = (t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log())).sum(dim=1).mean() - - return bce_non, kl, mse_peaks, probs - -def accuracy_percentage(logits, targets, peak_thresh=0.5): - probs = torch.sigmoid(logits) - preds_bin = (probs >= 0.5).float() - labels = (targets >= peak_thresh).float() - correct = (preds_bin == labels).float().sum() - total = torch.numel(labels) - return (correct / max(1, total)).item() * 100.0 diff --git a/dpacman/classifier/model/make_pair_list.py b/dpacman/classifier/model/make_pair_list.py deleted file mode 100644 index 03d429b7585352d746e95811a0dfd73ef0a9331c..0000000000000000000000000000000000000000 --- a/dpacman/classifier/model/make_pair_list.py +++ /dev/null @@ -1,220 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import numpy as np -import pandas as pd -from pathlib import Path -import random -import sys - -def read_ids_file(p): - p = Path(p) - if not p.exists(): - raise 
FileNotFoundError(f"IDs file not found: {p}") - return [line.strip() for line in p.open() if line.strip()] - -def split_embeddings(emb_path, ids_path, out_dir, prefix): - out_dir = Path(out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - - if not Path(emb_path).exists(): - raise FileNotFoundError(f"Embedding file not found: {emb_path}") - if not Path(ids_path).exists(): - raise FileNotFoundError(f"IDs file not found: {ids_path}") - - if emb_path.endswith(".npz"): - data = np.load(emb_path, allow_pickle=True) - if "embeddings" in data: - emb = data["embeddings"] - else: - raise ValueError(f"{emb_path} missing 'embeddings' key") - else: - emb = np.load(emb_path) - - ids = read_ids_file(ids_path) - if len(ids) != emb.shape[0]: - print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr) - - mapping = {} - for i, ident in enumerate(ids): - if i >= emb.shape[0]: - print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr) - continue - arr = emb[i] - out_file = out_dir / f"{prefix}_{ident}.npy" - np.save(out_file, arr) - mapping[ident] = str(out_file) - return mapping - -def extract_symbol_from_tf_id(full_id: str) -> str: - """ - Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN', - return the gene symbol uppercase (e.g., 'ZBTB5'). - """ - if "|" in full_id: - try: - # format sp|Accession|SYMBOL_HUMAN - genepart = full_id.split("|")[2] - except IndexError: - genepart = full_id - else: - genepart = full_id - symbol = genepart.split("_")[0] - return symbol.upper() - -def build_tf_symbol_map(tf_map): - """ - Build mapping gene_symbol -> list of embedding paths. - """ - symbol_map = {} - for full_id, path in tf_map.items(): - symbol = extract_symbol_from_tf_id(full_id) - symbol_map.setdefault(symbol, []).append(path) - return symbol_map - -def tf_key_from_path(path: str) -> str: - """ - Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'. - """ - stem = Path(path).stem # e.g., tf_sp|O15062|ZBTB5_HUMAN - # remove leading prefix if present (tf_) - if "_" in stem: - _, rest = stem.split("_", 1) - else: - rest = stem - return extract_symbol_from_tf_id(rest) - -def dna_key_from_path(path: str) -> str: - """ - Given .../dna_peak42.npy -> 'peak42' - """ - stem = Path(path).stem - if "_" in stem: - _, rest = stem.split("_", 1) - else: - rest = stem - return rest - -def main(): - parser = argparse.ArgumentParser( - description="Build TF-DNA pair list from final.csv with gene-symbol normalization for TFs." - ) - parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence") - parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)") - parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (e.g., peak*.ids)") - parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)") - parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (e.g., sp|...|... 
ids)") - parser.add_argument("--out_dir", required=True, help="Output directory") - parser.add_argument("--neg_per_positive", type=int, default=2, help="Negatives per positive (half same-TF, half same-DNA)") - parser.add_argument("--seed", type=int, default=42) - args = parser.parse_args() - - random.seed(args.seed) - out_dir = Path(args.out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - - # Load final.csv - df = pd.read_csv(args.final_csv, dtype=str) - if "TF_id" not in df.columns or "dna_sequence" not in df.columns: - raise RuntimeError("final.csv must have columns TF_id and dna_sequence") - - # Assign dna_id (unique per dna_sequence) - unique_seqs = df["dna_sequence"].drop_duplicates().tolist() - seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)} - df["dna_id"] = df["dna_sequence"].map(seq_to_id) - enriched_csv = out_dir / "final_with_dna_id.csv" - df.to_csv(enriched_csv, index=False) - print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}") - - # Split embeddings into per-item files - print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}") - dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna") - print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})") - print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}") - tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf") - print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})") - - # Build gene-symbol normalized map - tf_symbol_map = build_tf_symbol_map(tf_map) - print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}") - - # Diagnostic overlaps - norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique()) - available_tf_symbols = set(tf_symbol_map.keys()) - intersect_tf = norm_tf_in_final & available_tf_symbols - print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}") - print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}") - print(f"[i] Intersection count: {len(intersect_tf)}") - if len(intersect_tf) == 0: - print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr) - print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr) - print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr) - sys.exit(1) - - dna_ids_final = set(df["dna_id"].unique()) - available_dna_ids = set(dna_map.keys()) - intersect_dna = dna_ids_final & available_dna_ids - print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. 
Intersection: {len(intersect_dna)}") - if len(intersect_dna) == 0: - print("[ERROR] No overlap on DNA ids.", file=sys.stderr) - sys.exit(1) - - # Build positive pairs - positives = [] - for _, row in df.iterrows(): - tf_raw = row["TF_id"] - tf_symbol = tf_raw.split("_seq")[0].upper() - dnaid = row["dna_id"] - if tf_symbol not in tf_symbol_map: - continue - if dnaid not in dna_map: - continue - # pick the first embedding for that symbol - tf_embedding_path = tf_symbol_map[tf_symbol][0] - positives.append((tf_embedding_path, dna_map[dnaid], 1)) - print(f"[i] Constructed {len(positives)} positive pairs after TF symbol resolution") - - if len(positives) == 0: - print("[ERROR] No positive pairs could be constructed; aborting.", file=sys.stderr) - sys.exit(1) - - # Build negative samples - all_tf_symbols = sorted(tf_symbol_map.keys()) - all_dnaids = sorted(dna_map.keys()) - positive_set = set() - for tf_path, dna_path, _ in positives: - tf_key = tf_key_from_path(tf_path) - dna_key = dna_key_from_path(dna_path) - positive_set.add((tf_key, dna_key)) - - negatives = [] - half = args.neg_per_positive // 2 - for tf_path, dna_path, _ in positives: - tf_key = tf_key_from_path(tf_path) - dna_key = dna_key_from_path(dna_path) - # same TF, different DNA - for _ in range(half): - candidate_dna = random.choice(all_dnaids) - if candidate_dna == dna_key or (tf_key, candidate_dna) in positive_set: - continue - negatives.append((tf_path, dna_map[candidate_dna], 0)) - # same DNA, different TF - for _ in range(half): - candidate_tf_symbol = random.choice(all_tf_symbols) - if candidate_tf_symbol == tf_key or (candidate_tf_symbol, dna_key) in positive_set: - continue - # pick its first embedding - candidate_tf_path = tf_symbol_map[candidate_tf_symbol][0] - negatives.append((candidate_tf_path, dna_map[dnaid], 0)) - - print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive={args.neg_per_positive})") - - # Write pair list - pair_list_path = out_dir / "pair_list.tsv" - with open(pair_list_path, "w") as f: - for binder_path, glm_path, label in positives + negatives: - # binder=TF, glm=DNA - f.write(f"{binder_path}\t{glm_path}\t{label}\n") - print(f"[i] Wrote {len(positives)+len(negatives)} examples to {pair_list_path}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/dpacman/classifier/model/make_peak_fasta.py b/dpacman/classifier/model/make_peak_fasta.py deleted file mode 100644 index 9c7000d1e7b8f2170b703d364085e4f8f899b0d2..0000000000000000000000000000000000000000 --- a/dpacman/classifier/model/make_peak_fasta.py +++ /dev/null @@ -1,13 +0,0 @@ -import pandas as pd -from pathlib import Path - -df = pd.read_csv("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv", dtype=str) # adjust path if needed -# get unique sequences -uniq = df[["dna_sequence"]].drop_duplicates().reset_index(drop=True) -# make headers: e.g., peak0, peak1, ... 
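# ── Editor's note (hedged, not part of the patch) ──────────────────────────
# In the make_pair_list.py block removed above, the "same DNA, different TF"
# negative-sampling branch appends `dna_map[dnaid]`, but `dnaid` is a stale
# loop variable left over from the earlier positives loop, so every such
# negative reuses the last row's DNA id. The intent was presumably the current
# pair's DNA path, e.g.:
#
#     negatives.append((candidate_tf_path, dna_path, 0))
#
# Inferred from the surrounding code only; the script is deleted by this patch,
# so this is a reviewer note for comparing old and new behaviour.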
-out_fa = Path("binding_peaks_unique.fa") -with open(out_fa, "w") as f: - for i, seq in enumerate(uniq["dna_sequence"]): - header = f">peak{i}" - f.write(f"{header}\n{seq}\n") -print(f"Wrote {len(uniq)} unique binding sequences to {out_fa}") \ No newline at end of file diff --git a/dpacman/classifier/model/train.py b/dpacman/classifier/model/train.py deleted file mode 100644 index 48ce1b4615a79ad802f87a424d984635beefc075..0000000000000000000000000000000000000000 --- a/dpacman/classifier/model/train.py +++ /dev/null @@ -1,304 +0,0 @@ -import argparse, random, sys -from pathlib import Path - -import numpy as np -import pandas as pd -import torch -from torch import nn -from torch.utils.data import Dataset, DataLoader, Sampler -# from sklearn.random_projection import GaussianRandomProjection # OLD (kept): projection was removed earlier -import matplotlib.pyplot as plt - -import torch.amp as amp -from torch.nn import functional as F -from model import BindPredictor - -# ─────────────── utilities ──────────────────────────────────────────────── -def parse_pair_list(path): - binders, glms = [], [] - with open(path) as f: - for ln, line in enumerate(f,1): - parts = line.strip().split() - if len(parts) < 2: continue - b,g = parts[0], parts[1] - binders.append(b); glms.append(g) - return binders, glms - -class ListBatchSampler(Sampler): - def __init__(self, batches): self.batches = batches - def __iter__(self): return iter(self.batches) - def __len__(self): return len(self.batches) - -def make_buckets(idxs, glm_paths, batch_size, n_buckets=10, seed=42): - rng = random.Random(seed) - lengths = [(i, np.load(glm_paths[i]).shape[0]) for i in idxs] - lengths.sort(key=lambda x: x[1]) - size = max(1, int(np.ceil(len(lengths)/n_buckets))) - buckets = [lengths[i:i+size] for i in range(0,len(lengths),size)] - batches = [] - for bucket in buckets: - ids = [i for i,_ in bucket] - rng.shuffle(ids) - for i in range(0,len(ids),batch_size): - batches.append(ids[i:i+batch_size]) - rng.shuffle(batches) - return batches - -def dna_key_from_path(path: str) -> str: - """.../dna_peak42.npy -> 'peak42'""" - stem = Path(path).stem - if "_" in stem: - _, rest = stem.split("_", 1) - else: - rest = stem - return rest - -def build_tf_cache(tf_paths, target_dim=256): - """ - Load raw TF embeddings without projecting; compression is learnable in the model. 
- """ - unique = sorted(set(tf_paths)) - print(f"[i] (Learnable) Preparing {len(unique)} TF files; target {target_dim}d inside the model", flush=True) - - pools, raw = [], [] - for p in unique: - arr = np.load(p) # (L, D) or (D,) - raw.append(arr) - pools.append(arr.mean(axis=0) if arr.ndim==2 else arr) - M = np.stack(pools,0) - orig_dim = M.shape[1] - print(f"[i] Pooled shape → {M.shape} (orig_dim={orig_dim})", flush=True) - - cache = {} - for i,p in enumerate(unique): - arr = raw[i] - # OLD: projection here (removed) - cache[p] = arr - print("[i] TF cache ready (raw); compression will be learned.", flush=True) - return cache - -def evaluate(model, dl, device, alpha, beta, gamma, peak_thresh, eps=1e-8): - model.eval() - tot_loss, tot_acc = 0.0, 0.0 - n_batches = 0 - with torch.no_grad(): - for b,g,t in dl: - b,g,t = b.to(device), g.to(device), t.to(device) - logits = model(b,g) - bce_non, kl, mse_peaks, _ = combined_loss_components(logits, t, peak_thresh=peak_thresh, eps=eps) - loss = alpha*bce_non + beta*kl + gamma*mse_peaks - acc = accuracy_percentage(logits, t, peak_thresh=peak_thresh) - tot_loss += loss.item(); tot_acc += acc; n_batches += 1 - if n_batches == 0: return float("nan"), float("nan") - return tot_loss / n_batches, tot_acc / n_batches - -# ─────────────── cluster-aware splitting ────────────────────────────────── -def assign_clusters_to_splits(cluster_to_indices, val_frac=0.10, test_frac=0.10, seed=42): - """ - cluster_to_indices: dict[cluster_id] -> list of example indices (from pair_list) in that cluster - We greedily pack whole clusters into val/test until hitting targets (#examples), rest to train. - """ - rng = random.Random(seed) - clusters = list(cluster_to_indices.items()) - rng.shuffle(clusters) - - total = sum(len(ixs) for _, ixs in clusters) - target_val = int(round(total * val_frac)) - target_test = int(round(total * test_frac)) - cur_val = cur_test = 0 - - tr_ix, va_ix, te_ix = [], [], [] - for cid, ixs in clusters: - c = len(ixs) - if cur_val + c <= target_val: - va_ix.extend(ixs); cur_val += c - elif cur_test + c <= target_test: - te_ix.extend(ixs); cur_test += c - else: - tr_ix.extend(ixs) - return tr_ix, va_ix, te_ix - -# ─────────────── train & main ──────────────────────────────────────────── -def main(): - p = argparse.ArgumentParser() - p.add_argument("--pair_list", required=True) - p.add_argument("--final_csv", required=True) - p.add_argument("--out_dir", required=True) - p.add_argument("--epochs", type=int, default=10) - p.add_argument("--batch_size", type=int, default=16) - p.add_argument("--accum_steps", type=int, default=4) - p.add_argument("--lr", type=float, default=1e-4) - p.add_argument("--device", default="cuda") - p.add_argument("--seed", type=int, default=42) - p.add_argument("--alpha", type=float, default=0.5) - p.add_argument("--beta", type=float, default=0.6) - p.add_argument("--gamma", type=float, default=0.6) - p.add_argument("--peak_thresh", type=float, default=0.5) - # NEW: fractions for cluster-aware split (used only if cluster_id present) - p.add_argument("--val_frac", type=float, default=0.10) - p.add_argument("--test_frac", type=float, default=0.10) - args = p.parse_args() - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - device = torch.device(args.device if torch.cuda.is_available() else "cpu") - - # 1) load pair list & final.csv (now may include cluster_id) - tf_paths, dna_paths = parse_pair_list(args.pair_list) - final_df = pd.read_csv(args.final_csv, dtype=str) - print(f"[i] Loaded 
{len(tf_paths)} pairs", flush=True) - - tf_cache = build_tf_cache(tf_paths, target_dim=256) - - # detect binder/DNA dims - sample_tf = tf_cache[tf_paths[0]] - binder_input_dim = sample_tf.shape[1] if sample_tf.ndim == 2 else sample_tf.shape[0] - glm_input_dim = 256 - - # 2) cluster-aware split if possible - use_cluster_split = ("cluster_id" in final_df.columns) - if use_cluster_split: - print("[i] Cluster column detected in final_csv; performing cluster-aware split.", flush=True) - # build dna_id -> cluster_id map - cid_map = (final_df[["dna_id","cluster_id"]].dropna().drop_duplicates() - .set_index("dna_id")["cluster_id"].to_dict()) - - # map each example (by index) to its dna_id and cluster - example_dna_ids = [dna_key_from_path(p) for p in dna_paths] - example_clusters = [] - missing = 0 - for did in example_dna_ids: - if did in cid_map: - example_clusters.append(cid_map[did]) - else: - # fallback: treat singleton cluster - example_clusters.append(f"singleton::{did}") - missing += 1 - if missing: - print(f"[WARN] {missing} dna_ids from pair_list not found in cluster map; treating as singleton clusters.", flush=True) - - # build cluster -> indices - cluster_to_indices = {} - for i, cid in enumerate(example_clusters): - cluster_to_indices.setdefault(cid, []).append(i) - - tr_idx, va_idx, te_idx = assign_clusters_to_splits( - cluster_to_indices, - val_frac=args.val_frac, test_frac=args.test_frac, seed=args.seed - ) - print(f"[i] Cluster split sizes (examples): train={len(tr_idx)} val={len(va_idx)} test={len(te_idx)}", flush=True) - - # helper to subset paths - def subset_by_indices(ixs): - return [tf_paths[i] for i in ixs], [dna_paths[i] for i in ixs] - - tr_t, tr_d = subset_by_indices(tr_idx) - va_t, va_d = subset_by_indices(va_idx) - te_t, te_d = subset_by_indices(te_idx) - - else: - print("[i] No cluster_id in final_csv; using random 80/10/10 split (OLD behavior).", flush=True) - # OLD random split (kept, now under else) - N = len(tf_paths) - idxs = list(range(N)); random.shuffle(idxs) - n_tr = int(0.8*N); n_va = int(0.1*N) - tr, va, te = idxs[:n_tr], idxs[n_tr:n_tr+n_va], idxs[n_tr+n_va:] - - def subset(idxs_): - return [tf_paths[i] for i in idxs_], [dna_paths[i] for i in idxs_] - - tr_t, tr_d = subset(tr) - va_t, va_d = subset(va) - te_t, te_d = subset(te) - - # 3) bucketed samplers (unchanged, but now use the cluster-aware subsets when available) - tr_bs = make_buckets(list(range(len(tr_t))), tr_d, args.batch_size, n_buckets=10, seed=args.seed) - va_bs = make_buckets(list(range(len(va_t))), va_d, args.batch_size, n_buckets=5, seed=args.seed+1) - te_bs = make_buckets(list(range(len(te_t))), te_d, args.batch_size, n_buckets=5, seed=args.seed+2) - - tr_dl = DataLoader(PairDataset(tr_t, tr_d, final_df, tf_cache), - batch_sampler=ListBatchSampler(tr_bs), - collate_fn=collate_fn) - va_dl = DataLoader(PairDataset(va_t, va_d, final_df, tf_cache), - batch_sampler=ListBatchSampler(va_bs), - collate_fn=collate_fn) - te_dl = DataLoader(PairDataset(te_t, te_d, final_df, tf_cache), - batch_sampler=ListBatchSampler(te_bs), - collate_fn=collate_fn) - - # 4) model, optimizer, scaler - model = BindPredictor(binder_input_dim=binder_input_dim, - glm_input_dim=glm_input_dim, - compressed_dim=256, - hidden_dim=256, - heads=8, num_layers=4, - use_local_cnn_on_glm=True).to(device) - optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) - scaler = amp.GradScaler('cuda') - - history, best_val = {"train": [], "val": []}, float("inf") - od = Path(args.out_dir); od.mkdir(exist_ok=True, parents=True) - - 
for ep in range(1, args.epochs+1): - print(f"┌─[Epoch {ep}]────────────────────────", flush=True) - model.train() - optimizer.zero_grad() - acc_loss_sum, acc_acc_sum, n_train_batches = 0.0, 0.0, 0 - - for i, (b, g, t) in enumerate(tr_dl): - b, g, t = b.to(device), g.to(device), t.to(device) - with amp.autocast('cuda'): - logits = model(b, g) - bce_non, kl, mse_peaks, probs = combined_loss_components( - logits, t, peak_thresh=args.peak_thresh - ) - loss = args.alpha*bce_non + args.beta*kl + args.gamma*mse_peaks - loss = loss / args.accum_steps - - scaler.scale(loss).backward() - - if (i + 1) % args.accum_steps == 0: - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - with torch.no_grad(): - acc_loss_sum += (loss.item() * args.accum_steps) - acc_acc_sum += accuracy_percentage(logits, t, peak_thresh=args.peak_thresh) - n_train_batches += 1 - - del b, g, t, logits, probs, loss, bce_non, kl, mse_peaks - torch.cuda.empty_cache() - - # finalize if leftovers - if n_train_batches % args.accum_steps != 0: - scaler.step(optimizer); scaler.update(); optimizer.zero_grad() - - train_loss = acc_loss_sum / max(1, n_train_batches) - train_acc = acc_acc_sum / max(1, n_train_batches) - - val_loss, val_acc = evaluate(model, va_dl, device, - alpha=args.alpha, beta=args.beta, gamma=args.gamma, - peak_thresh=args.peak_thresh) - print(f"[Epoch {ep}] train_loss={train_loss:.4f} train_acc={train_acc:.2f}% " - f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}%", flush=True) - - history["train"].append(train_loss) - history["val"].append(val_loss) - if val_loss < best_val: - best_val = val_loss - torch.save(model.state_dict(), od/"best_model.pt") - print(f" Saved new best_model.pt (val_loss={val_loss:.4f}, val_acc={val_acc:.2f}%)", flush=True) - - torch.save(model.state_dict(), od/"last_model.pt") - - fig, ax = plt.subplots() - ax.plot(history["train"], label="train") - ax.plot(history["val"], label="val") - ax.set_xlabel("epoch"); ax.set_ylabel("combined loss"); ax.legend() - fig.savefig(od/"loss_curve.png") - print(f"✅ Done → outputs in {od}", flush=True) - -if __name__=="__main__": - main() diff --git a/dpacman/classifier/model_tmp/clustering_data.py b/dpacman/classifier/model_tmp/clustering_data.py index a8276866cc86ed600841a5c679efc797418a89d3..a7958e2c796df850c7772dc9622f7e30b4b05ed5 100644 --- a/dpacman/classifier/model_tmp/clustering_data.py +++ b/dpacman/classifier/model_tmp/clustering_data.py @@ -12,12 +12,14 @@ from collections import defaultdict # Original helpers (kept; some lightly edited/commented where needed) # ───────────────────────────────────────────────────────────────────────── + def read_ids_file(p): p = Path(p) if not p.exists(): raise FileNotFoundError(f"IDs file not found: {p}") return [line.strip() for line in p.open() if line.strip()] + def split_embeddings(emb_path, ids_path, out_dir, prefix): out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) @@ -38,12 +40,17 @@ def split_embeddings(emb_path, ids_path, out_dir, prefix): ids = read_ids_file(ids_path) if len(ids) != emb.shape[0]: - print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr) + print( + f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", + file=sys.stderr, + ) mapping = {} for i, ident in enumerate(ids): if i >= emb.shape[0]: - print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr) + print( + f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr + ) continue arr = emb[i] 
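        # Editorial comment (hedged, not part of the patch): `emb[i]` is one row
        # of the stacked embedding matrix loaded above -- a pooled (D,) vector for
        # .npy inputs, or a padded (L_max, D) per-token array when the input is a
        # .npz written by pad_token_embeddings -- and it is saved unchanged as
        # <prefix>_<ident>.npy, with `mapping[ident]` recording that path.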
out_file = out_dir / f"{prefix}_{ident}.npy" @@ -51,6 +58,7 @@ def split_embeddings(emb_path, ids_path, out_dir, prefix): mapping[ident] = str(out_file) return mapping + def extract_symbol_from_tf_id(full_id: str) -> str: """ Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN', @@ -67,6 +75,7 @@ def extract_symbol_from_tf_id(full_id: str) -> str: symbol = genepart.split("_")[0] return symbol.upper() + def build_tf_symbol_map(tf_map): """ Build mapping gene_symbol -> list of embedding paths. @@ -77,6 +86,7 @@ def build_tf_symbol_map(tf_map): symbol_map.setdefault(symbol, []).append(path) return symbol_map + def tf_key_from_path(path: str) -> str: """ Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'. @@ -89,6 +99,7 @@ def tf_key_from_path(path: str) -> str: rest = stem return extract_symbol_from_tf_id(rest) + def dna_key_from_path(path: str) -> str: """ Given .../dna_peak42.npy -> 'peak42' @@ -100,10 +111,12 @@ def dna_key_from_path(path: str) -> str: rest = stem return rest + # ───────────────────────────────────────────────────────────────────────── # New helpers for MMseqs clustering & cluster-level splitting # ───────────────────────────────────────────────────────────────────────── + def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None: """ Write unique DNA sequences to FASTA using dna_id as header. @@ -116,6 +129,7 @@ def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None: seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "") f.write(f">{did}\n{seq}\n") + def run_mmseqs_easy_cluster( mmseqs_bin: str, fasta: Path, @@ -133,11 +147,17 @@ def run_mmseqs_easy_cluster( out_prefix.parent.mkdir(parents=True, exist_ok=True) cmd = [ - mmseqs_bin, "easy-cluster", - str(fasta), str(out_prefix), str(tmp_dir), - "--min-seq-id", str(min_seq_id), - "-c", str(coverage), - "--cov-mode", str(cov_mode), + mmseqs_bin, + "easy-cluster", + str(fasta), + str(out_prefix), + str(tmp_dir), + "--min-seq-id", + str(min_seq_id), + "-c", + str(coverage), + "--cov-mode", + str(cov_mode), # You can add performance flags here if needed, e.g.: # "--threads", "8" ] @@ -157,14 +177,24 @@ def run_mmseqs_easy_cluster( cl_db = Path(str(out_prefix) + "_cluster") out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv") if in_db.exists() and cl_db.exists(): - cmd2 = [mmseqs_bin, "createtsv", str(in_db), str(in_db), str(cl_db), str(out_tsv)] + cmd2 = [ + mmseqs_bin, + "createtsv", + str(in_db), + str(in_db), + str(cl_db), + str(out_tsv), + ] print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True) subprocess.run(cmd2, check=True) if out_tsv.exists(): return out_tsv - raise FileNotFoundError("Could not locate clusters TSV from mmseqs. " - "Expected {default_tsv} or createtsv fallback.") + raise FileNotFoundError( + "Could not locate clusters TSV from mmseqs. " + "Expected {default_tsv} or createtsv fallback." 
+ ) + def parse_mmseqs_clusters(tsv_path: Path) -> dict: """ @@ -174,7 +204,7 @@ def parse_mmseqs_clusters(tsv_path: Path) -> dict: with open(tsv_path) as f: for line in f: parts = line.rstrip("\n").split("\t") - if len(parts) < 2: + if len(parts) < 2: continue rep, member = parts[0], parts[1] mapping[member] = rep @@ -183,10 +213,10 @@ def parse_mmseqs_clusters(tsv_path: Path) -> dict: mapping[rep] = rep return mapping -def assign_clusters_to_splits(cluster_rep_to_members: dict, - val_frac: float, - test_frac: float, - seed: int = 42): + +def assign_clusters_to_splits( + cluster_rep_to_members: dict, val_frac: float, test_frac: float, seed: int = 42 +): """ cluster_rep_to_members: dict[rep] = [members...] Returns: dict with keys 'train','val','test' mapping to sets of dna_id. @@ -208,38 +238,63 @@ def assign_clusters_to_splits(cluster_rep_to_members: dict, c = len(members) # Fill val first, then test, then train if cur_val + c <= target_val: - val_ids.update(members); cur_val += c + val_ids.update(members) + cur_val += c elif cur_test + c <= target_test: - test_ids.update(members); cur_test += c + test_ids.update(members) + cur_test += c else: train_ids.update(members) return {"train": train_ids, "val": val_ids, "test": test_ids} + # ───────────────────────────────────────────────────────────────────────── # Main # ───────────────────────────────────────────────────────────────────────── + def main(): parser = argparse.ArgumentParser( description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage." ) - parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence") - parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)") - parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)") - parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)") - parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)") + parser.add_argument( + "--final_csv", required=True, help="final.csv with TF_id and dna_sequence" + ) + parser.add_argument( + "--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)" + ) + parser.add_argument( + "--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)" + ) + parser.add_argument( + "--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)" + ) + parser.add_argument( + "--tf_ids", required=True, help="IDs file for TF embeddings (sp|... 
ids)" + ) parser.add_argument("--out_dir", required=True, help="Output directory") parser.add_argument("--seed", type=int, default=42) # NEW: MMseqs options & split fractions parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary") - parser.add_argument("--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id") - parser.add_argument("--cov", type=float, default=0.8, help="MMseqs -c coverage fraction") - parser.add_argument("--cov_mode", type=int, default=1, help="MMseqs --cov-mode (1 = coverage of target)") + parser.add_argument( + "--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id" + ) + parser.add_argument( + "--cov", type=float, default=0.8, help="MMseqs -c coverage fraction" + ) + parser.add_argument( + "--cov_mode", + type=int, + default=1, + help="MMseqs --cov-mode (1 = coverage of target)", + ) parser.add_argument("--val_frac", type=float, default=0.10) parser.add_argument("--test_frac", type=float, default=0.10) - parser.add_argument("--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)") + parser.add_argument( + "--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)" + ) args = parser.parse_args() random.seed(args.seed) @@ -260,12 +315,24 @@ def main(): print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}") # Split embeddings into per-item files (unchanged) - print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}") - dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna") - print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})") - print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}") - tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf") - print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})") + print( + f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}" + ) + dna_map = split_embeddings( + args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna" + ) + print( + f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})" + ) + print( + f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}" + ) + tf_map = split_embeddings( + args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf" + ) + print( + f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})" + ) # Build gene-symbol normalized map tf_symbol_map = build_tf_symbol_map(tf_map) @@ -279,15 +346,28 @@ def main(): print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}") print(f"[i] Intersection count: {len(intersect_tf)}") if len(intersect_tf) == 0: - print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr) - print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr) - print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr) + print( + "[ERROR] No overlap between normalized TF_id and TF embedding symbols.", + file=sys.stderr, + ) + print( + "Sample normalized TFs from final.csv:", + sorted(list(norm_tf_in_final))[:30], + file=sys.stderr, + ) + print( + "Sample available TF symbols:", + sorted(list(available_tf_symbols))[:30], + file=sys.stderr, + ) sys.exit(1) dna_ids_final = set(df["dna_id"].unique()) available_dna_ids = set(dna_map.keys()) intersect_dna = dna_ids_final & 
available_dna_ids - print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}") + print( + f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}" + ) if len(intersect_dna) == 0: print("[ERROR] No overlap on DNA ids.", file=sys.stderr) sys.exit(1) @@ -295,7 +375,9 @@ def main(): # ── NEW: MMseqs clustering on DNA sequences ─────────────────────────── fasta_path = out_dir / "dna_unique.fasta" write_dna_fasta(df, fasta_path) - print(f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}") + print( + f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}" + ) tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp") cluster_prefix = out_dir / "mmseqs_dna_clusters" @@ -310,7 +392,7 @@ def main(): ) # Parse clusters - member_to_rep = parse_mmseqs_clusters(clusters_tsv) # dna_id -> rep_id + member_to_rep = parse_mmseqs_clusters(clusters_tsv) # dna_id -> rep_id # Build rep -> members list rep_to_members = defaultdict(list) for member, rep in member_to_rep.items(): @@ -331,10 +413,9 @@ def main(): print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}") # Assign entire clusters to splits - splits = assign_clusters_to_splits(rep_to_members, - val_frac=args.val_frac, - test_frac=args.test_frac, - seed=args.seed) + splits = assign_clusters_to_splits( + rep_to_members, val_frac=args.val_frac, test_frac=args.test_frac, seed=args.seed + ) for k in ["train", "val", "test"]: print(f"[i] {k}: {len(splits[k])} dna_ids") @@ -354,14 +435,22 @@ def main(): # decide split by dna_id cluster assignment if dnaid in splits["train"]: - positives_by_split["train"].append((tf_embedding_path, dnaid_to_path[dnaid], 1)) + positives_by_split["train"].append( + (tf_embedding_path, dnaid_to_path[dnaid], 1) + ) elif dnaid in splits["val"]: - positives_by_split["val"].append((tf_embedding_path, dnaid_to_path[dnaid], 1)) + positives_by_split["val"].append( + (tf_embedding_path, dnaid_to_path[dnaid], 1) + ) elif dnaid in splits["test"]: - positives_by_split["test"].append((tf_embedding_path, dnaid_to_path[dnaid], 1)) + positives_by_split["test"].append( + (tf_embedding_path, dnaid_to_path[dnaid], 1) + ) pos_count += 1 - print(f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})") + print( + f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})" + ) for k in ["train", "val", "test"]: print(f"[i] positives[{k}] = {len(positives_by_split[k])}") @@ -373,11 +462,14 @@ def main(): for split in ["train", "val", "test"]: out_tsv = out_dir / f"pair_list_{split}.tsv" with open(out_tsv, "w") as f: - for binder_path, glm_path, label in positives_by_split[split]: # + negatives if you add later + for binder_path, glm_path, label in positives_by_split[ + split + ]: # + negatives if you add later f.write(f"{binder_path}\t{glm_path}\t{label}\n") print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}") print("✅ Done. 
Cluster-aware splits ready.") + if __name__ == "__main__": main() diff --git a/dpacman/classifier/model_tmp/compress_embeddings.py b/dpacman/classifier/model_tmp/compress_embeddings.py index 248f0143ff2430b3c244b856cd853895c51d3552..b36fdbf284cf111894b97c72c07820e0afe996e4 100644 --- a/dpacman/classifier/model_tmp/compress_embeddings.py +++ b/dpacman/classifier/model_tmp/compress_embeddings.py @@ -7,6 +7,7 @@ import numpy as np import torch from torch import nn + class EmbeddingCompressor(nn.Module): def __init__(self, input_dim: int = 1280, output_dim: int = 256): super().__init__() @@ -19,26 +20,33 @@ class EmbeddingCompressor(nn.Module): """ if x.dim() == 2: # single example: mean over tokens - x = x.mean(dim=0, keepdim=True) # → (1, input_dim) + x = x.mean(dim=0, keepdim=True) # → (1, input_dim) else: # batch: mean over tokens - x = x.mean(dim=1) # → (batch, input_dim) - return self.fc(x) # → (batch, output_dim) + x = x.mean(dim=1) # → (batch, input_dim) + return self.fc(x) # → (batch, output_dim) + def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor): - arr = np.load(in_path) # shape (L, D) or (batch, L, D) + arr = np.load(in_path) # shape (L, D) or (batch, L, D) tensor = torch.from_numpy(arr).float() with torch.no_grad(): - compressed = model(tensor) # → (batch, 256) + compressed = model(tensor) # → (batch, 256) out = compressed.cpu().numpy() np.save(out_path, out) print(f"Saved {out_path}") + if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256­d") - parser.add_argument("--input_glob", type=str, required=True, - help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)") + parser.add_argument( + "--input_glob", + type=str, + required=True, + help="Glob for your .npy ESM embeddings (e.g. 
data/esm_*.npy)", + ) parser.add_argument("--output_dir", type=str, required=True) parser.add_argument("--esm_dim", type=int, default=1280) parser.add_argument("--out_dim", type=int, default=256) diff --git a/dpacman/classifier/model_tmp/compute_embeddings.py b/dpacman/classifier/model_tmp/compute_embeddings.py index 5646d3f0d5edf66dc7a7d2869cad780e3f49441c..d4ce3e053386ae8db3f360d0e3470edacc016240 100644 --- a/dpacman/classifier/model_tmp/compute_embeddings.py +++ b/dpacman/classifier/model_tmp/compute_embeddings.py @@ -14,6 +14,7 @@ Usage example (DNA + protein in one go): --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \ --device cuda """ + import os import re import argparse @@ -28,6 +29,7 @@ import time # ---- model wrappers ---- + class CaduceusEmbedder: def __init__(self, device, chunk_size=131_072, overlap=0): """ @@ -39,12 +41,14 @@ class CaduceusEmbedder: self.tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=True ) - self.model = AutoModel.from_pretrained( - model_name, trust_remote_code=True - ).to(device).eval() - self.device = device + self.model = ( + AutoModel.from_pretrained(model_name, trust_remote_code=True) + .to(device) + .eval() + ) + self.device = device self.chunk_size = chunk_size - self.step = chunk_size - overlap + self.step = chunk_size - overlap def embed(self, seqs): """ @@ -73,14 +77,13 @@ class CaduceusEmbedder: return_tensors="pt", padding=False, truncation=True, - max_length=self.chunk_size + max_length=self.chunk_size, ).to(self.device) with torch.no_grad(): out = self.model(**toks).last_hidden_state # (1, L, D) - outputs.append(out.cpu().numpy()[0]) # (L, D) + outputs.append(out.cpu().numpy()[0]) # (L, D) return outputs # list of variable-length (L_i, D) arrays - def benchmark(self, lengths=None): """ Time embedding on single-sequence of various lengths. 
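# ── Editor's sketch (illustrative, not part of the patch) ──────────────────
# The reformatted CaduceusEmbedder above sets `step = chunk_size - overlap` in
# __init__ and truncates each window to `chunk_size` tokens. A minimal
# stand-alone illustration of the window start offsets, assuming the embed
# loop advances by `step` (lengths below are made up for the example):

def window_starts(seq_len: int, chunk_size: int = 131_072, overlap: int = 0):
    """Start offsets of the fixed-stride windows CaduceusEmbedder is assumed to use."""
    step = chunk_size - overlap
    return list(range(0, seq_len, step))

# e.g. a 300 kb region with the default window and no overlap:
# window_starts(300_000) -> [0, 131072, 262144]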
@@ -101,10 +104,17 @@ class CaduceusEmbedder: t1 = time.perf_counter() print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms") + class SegmentNTEmbedder: def __init__(self, device): - self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True) - self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval() + self.tokenizer = AutoTokenizer.from_pretrained( + "InstaDeepAI/segment_nt", trust_remote_code=True + ) + self.model = ( + AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True) + .to(device) + .eval() + ) self.device = device def _adjust_length(self, input_ids): @@ -113,7 +123,12 @@ class SegmentNTEmbedder: remainder = (excl) % 4 if remainder != 0: pad_needed = 4 - remainder - pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device) + pad_tensor = torch.full( + (bs, pad_needed), + self.tokenizer.pad_token_id, + dtype=input_ids.dtype, + device=input_ids.device, + ) input_ids = torch.cat([input_ids, pad_tensor], dim=1) return input_ids @@ -135,7 +150,7 @@ class SegmentNTEmbedder: attention_mask = input_ids != self.tokenizer.pad_token_id input_ids = self._adjust_length(input_ids) - attention_mask = (input_ids != self.tokenizer.pad_token_id) + attention_mask = input_ids != self.tokenizer.pad_token_id with torch.no_grad(): outs = self.model( @@ -161,19 +176,26 @@ class SegmentNTEmbedder: class DNABertEmbedder: def __init__(self, device): - self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True) - self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device) - self.device = device + self.tokenizer = AutoTokenizer.from_pretrained( + "zhihan1996/DNA_bert_6", trust_remote_code=True + ) + self.model = AutoModel.from_pretrained( + "zhihan1996/DNA_bert_6", trust_remote_code=True + ).to(device) + self.device = device def embed(self, seqs): embs = [] for s in seqs: - tokens = self.tokenizer(s, return_tensors="pt", padding=True)["input_ids"].to(self.device) + tokens = self.tokenizer(s, return_tensors="pt", padding=True)[ + "input_ids" + ].to(self.device) with torch.no_grad(): out = self.model(tokens).last_hidden_state.mean(1) embs.append(out.cpu().numpy()) return np.vstack(embs) + class NucleotideTransformerEmbedder: def __init__(self, device): # HF “feature-extraction” returns a list of (L, D) arrays for each input @@ -181,7 +203,9 @@ class NucleotideTransformerEmbedder: self.pipe = pipeline( "feature-extraction", model="InstaDeepAI/nucleotide-transformer-500m-1000g", - device= -1 if device=="cpu" else 0 # HF uses -1 for CPU, 0 for GPU #:contentReference[oaicite:0]{index=0} + device=( + -1 if device == "cpu" else 0 + ), # HF uses -1 for CPU, 0 for GPU #:contentReference[oaicite:0]{index=0} ) def embed(self, seqs): @@ -191,8 +215,9 @@ class NucleotideTransformerEmbedder: """ all_embeddings = self.pipe(seqs, truncation=True, padding=True) # all_embeddings is a List of shape (L, D) arrays - pooled = [ np.mean(x, axis=0) for x in all_embeddings ] - return np.vstack(pooled) + pooled = [np.mean(x, axis=0) for x in all_embeddings] + return np.vstack(pooled) + # class ESMEmbedder: # def __init__(self, device): @@ -225,7 +250,9 @@ class ESMEmbedder: self.batch_converter = self.alphabet.get_batch_converter() self.model.to(device).eval() # determine max length: esm2 models vary; use default 1024 for esm1b - self.max_len = 4096 if self.is_esm2 else 1024 # adjust if your esm2 
variant has explicit limit + self.max_len = ( + 4096 if self.is_esm2 else 1024 + ) # adjust if your esm2 variant has explicit limit # for chunking: reserve 2 tokens if model uses BOS/EOS self.chunk_size = self.max_len - 2 self.overlap = self.chunk_size // 4 # 25% overlap to smooth boundaries @@ -280,7 +307,7 @@ class ESMEmbedder: # class ESMDBPEmbedder: # def __init__(self, device): -# base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S() +# base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S() # model_path = ( # Path(__file__).resolve().parent.parent # / "pretrained" / "ESM-DBP" / "ESM-DBP.model" @@ -310,12 +337,15 @@ class ESMEmbedder: # # skip start/end tokens # return reps[:, 1:-1].mean(1).cpu().numpy() + class ESMDBPEmbedder: def __init__(self, device): base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S() model_path = ( Path(__file__).resolve().parent.parent - / "pretrained" / "ESM-DBP" / "ESM-DBP.model" + / "pretrained" + / "ESM-DBP" + / "ESM-DBP.model" ) checkpoint = torch.load(model_path, map_location="cpu") clean_sd = {} @@ -372,6 +402,7 @@ class ESMDBPEmbedder: all_embeddings.append(seq_vec.cpu().numpy()) return np.vstack(all_embeddings) + class GPNEmbedder: def __init__(self, device): model_name = "songlab/gpn-msa-sapiens" @@ -383,16 +414,14 @@ class GPNEmbedder: def embed(self, seqs): inputs = self.tokenizer( - seqs, - return_tensors="pt", - padding=True, - truncation=True + seqs, return_tensors="pt", padding=True, truncation=True ).to(self.device) with torch.no_grad(): last_hidden = self.model(**inputs).last_hidden_state return last_hidden.mean(dim=1).cpu().numpy() + class ProGenEmbedder: def __init__(self, device): model_name = "jinyuan22/ProGen2-base" @@ -402,29 +431,36 @@ class ProGenEmbedder: def embed(self, seqs): inputs = self.tokenizer( - seqs, - return_tensors="pt", - padding=True, - truncation=True + seqs, return_tensors="pt", padding=True, truncation=True ).to(self.device) with torch.no_grad(): last_hidden = self.model(**inputs).last_hidden_state return last_hidden.mean(dim=1).cpu().numpy() + # ---- main pipeline ---- + def get_embedder(name, device, for_dna=True): name = name.lower() if for_dna: - if name=="caduceus": return CaduceusEmbedder(device) - if name=="dnabert": return DNABertEmbedder(device) - if name=="nucleotide": return NucleotideTransformerEmbedder(device) - if name=="gpn": return GPNEmbedder(device) - if name=="segmentnt": return SegmentNTEmbedder(device) + if name == "caduceus": + return CaduceusEmbedder(device) + if name == "dnabert": + return DNABertEmbedder(device) + if name == "nucleotide": + return NucleotideTransformerEmbedder(device) + if name == "gpn": + return GPNEmbedder(device) + if name == "segmentnt": + return SegmentNTEmbedder(device) else: - if name in ("esm",): return ESMEmbedder(device) - if name in ("esm-dbp","esm_dbp"): return ESMDBPEmbedder(device) - if name=="progen": return ProGenEmbedder(device) + if name in ("esm",): + return ESMEmbedder(device) + if name in ("esm-dbp", "esm_dbp"): + return ESMDBPEmbedder(device) + if name == "progen": + return ProGenEmbedder(device) raise ValueError(f"Unknown model {name} (for_dna={for_dna})") @@ -446,20 +482,28 @@ def pad_token_embeddings(list_of_arrays, pad_value=0.0): mask[i, :L] = True return padded, mask + def embed_and_save(seqs, ids, embedder, out_path): embs = embedder.embed(seqs) # Decide whether we got variable-length per-token outputs (list of (L, D)) - is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and 
embs[0].ndim == 2 + is_variable_token = ( + isinstance(embs, (list, tuple)) + and len(embs) > 0 + and hasattr(embs[0], "shape") + and embs[0].ndim == 2 + ) if is_variable_token: # pad to (N, L_max, D) + mask padded, mask = pad_token_embeddings(embs) # Save both embeddings and mask together in an .npz for convenience - np.savez_compressed(out_path.with_suffix(".caduceus.npz"), - embeddings=padded, - mask=mask, - ids=np.array(ids, dtype=object)) + np.savez_compressed( + out_path.with_suffix(".caduceus.npz"), + embeddings=padded, + mask=mask, + ids=np.array(ids, dtype=object), + ) else: # fixed shape output, e.g., pooled (N, D) array = np.vstack(embs) if isinstance(embs, list) else embs @@ -468,17 +512,31 @@ def embed_and_save(seqs, ids, embedder, out_path): f.write("\n".join(ids)) -if __name__=="__main__": +if __name__ == "__main__": p = argparse.ArgumentParser() - p.add_argument("--peak-fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs") - p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes") - p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step") #if glm embeddings successful but not plm embeddings - p.add_argument("--tf-fasta", required=True, help="input TF FASTA file") - p.add_argument("--chrom-model", default="caduceus") - p.add_argument("--tf-model", default="esm-dbp") - p.add_argument("--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings") - p.add_argument("--device", default="cpu") + p.add_argument( + "--peak-fasta", + default="binding_peaks_unique.fa", + help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs", + ) + p.add_argument( + "--genome-json-dir", + default=None, + help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes", + ) + p.add_argument( + "--skip-dna", + action="store_true", + help="if set, skip the chromosome embedding step", + ) # if glm embeddings successful but not plm embeddings + p.add_argument("--tf-fasta", required=True, help="input TF FASTA file") + p.add_argument("--chrom-model", default="caduceus") + p.add_argument("--tf-model", default="esm-dbp") + p.add_argument( + "--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings" + ) + p.add_argument("--device", default="cpu") args = p.parse_args() os.makedirs(args.out_dir, exist_ok=True) @@ -495,7 +553,10 @@ if __name__=="__main__": for rec in SeqIO.parse(peak_fasta, "fasta"): peak_ids.append(rec.id) peak_seqs.append(str(rec.seq)) - print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True) + print( + f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", + flush=True, + ) dna_embedder = get_embedder(args.chrom_model, device, for_dna=True) out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy" embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks) @@ -503,7 +564,9 @@ if __name__=="__main__": # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M) genome_dir = Path(args.genome_json_dir) chrom_seqs, chrom_ids = [], [] - primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$") + primary_pattern = re.compile( + r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$" + ) for j in 
sorted(genome_dir.iterdir()): if not primary_pattern.match(j.name): continue @@ -519,7 +582,9 @@ if __name__=="__main__": if len(seq) > cutoff ] if long_chroms: - print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff)) + print( + "⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff) + ) for chrom, L in long_chroms: print(f" {chrom}: {L} bases") else: @@ -529,10 +594,11 @@ if __name__=="__main__": out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy" embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom) else: - raise ValueError("No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs.") - + raise ValueError( + "No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs." + ) - #Load TF sequences + # Load TF sequences tf_seqs, tf_ids = [], [] for record in SeqIO.parse(args.tf_fasta, "fasta"): tf_ids.append(record.id) @@ -543,4 +609,4 @@ if __name__=="__main__": out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy" embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf) - print("Done.") \ No newline at end of file + print("Done.") diff --git a/dpacman/classifier/model_tmp/extract_tf_symbols.py b/dpacman/classifier/model_tmp/extract_tf_symbols.py index 3c833d8454b71fdd7b05088f73a44605a31b95f0..c2ef113215fbc07a7cc6e399599c590c40499061 100644 --- a/dpacman/classifier/model_tmp/extract_tf_symbols.py +++ b/dpacman/classifier/model_tmp/extract_tf_symbols.py @@ -5,9 +5,11 @@ from pathlib import Path FINAL_CSV = Path("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv") OUT_SYMBOLS = Path("tf_symbols.txt") + def normalize_tf(tf_id: str) -> str: return tf_id.split("_seq")[0].upper() + def main(): df = pd.read_csv(FINAL_CSV, dtype=str) if "TF_id" not in df.columns: @@ -23,5 +25,6 @@ def main(): # Optional: show sample print("Sample symbols:", normalized[:50]) + if __name__ == "__main__": main() diff --git a/dpacman/classifier/model_tmp/make_pair_list.py b/dpacman/classifier/model_tmp/make_pair_list.py index 03d429b7585352d746e95811a0dfd73ef0a9331c..09ac5cf936af83a3dc12c31b3d0a695b0f7faeab 100644 --- a/dpacman/classifier/model_tmp/make_pair_list.py +++ b/dpacman/classifier/model_tmp/make_pair_list.py @@ -6,12 +6,14 @@ from pathlib import Path import random import sys + def read_ids_file(p): p = Path(p) if not p.exists(): raise FileNotFoundError(f"IDs file not found: {p}") return [line.strip() for line in p.open() if line.strip()] + def split_embeddings(emb_path, ids_path, out_dir, prefix): out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) @@ -32,12 +34,17 @@ def split_embeddings(emb_path, ids_path, out_dir, prefix): ids = read_ids_file(ids_path) if len(ids) != emb.shape[0]: - print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr) + print( + f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", + file=sys.stderr, + ) mapping = {} for i, ident in enumerate(ids): if i >= emb.shape[0]: - print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr) + print( + f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr + ) continue arr = emb[i] out_file = out_dir / f"{prefix}_{ident}.npy" @@ -45,6 +52,7 @@ def split_embeddings(emb_path, ids_path, out_dir, prefix): mapping[ident] = str(out_file) return mapping + def extract_symbol_from_tf_id(full_id: str) -> str: """ Given a TF 
embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN', @@ -61,6 +69,7 @@ def extract_symbol_from_tf_id(full_id: str) -> str: symbol = genepart.split("_")[0] return symbol.upper() + def build_tf_symbol_map(tf_map): """ Build mapping gene_symbol -> list of embedding paths. @@ -71,6 +80,7 @@ def build_tf_symbol_map(tf_map): symbol_map.setdefault(symbol, []).append(path) return symbol_map + def tf_key_from_path(path: str) -> str: """ Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'. @@ -83,6 +93,7 @@ def tf_key_from_path(path: str) -> str: rest = stem return extract_symbol_from_tf_id(rest) + def dna_key_from_path(path: str) -> str: """ Given .../dna_peak42.npy -> 'peak42' @@ -94,17 +105,35 @@ def dna_key_from_path(path: str) -> str: rest = stem return rest + def main(): parser = argparse.ArgumentParser( description="Build TF-DNA pair list from final.csv with gene-symbol normalization for TFs." ) - parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence") - parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)") - parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (e.g., peak*.ids)") - parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)") - parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (e.g., sp|...|... ids)") + parser.add_argument( + "--final_csv", required=True, help="final.csv with TF_id and dna_sequence" + ) + parser.add_argument( + "--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)" + ) + parser.add_argument( + "--dna_ids", required=True, help="IDs file for DNA embeddings (e.g., peak*.ids)" + ) + parser.add_argument( + "--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)" + ) + parser.add_argument( + "--tf_ids", + required=True, + help="IDs file for TF embeddings (e.g., sp|...|... 
ids)", + ) parser.add_argument("--out_dir", required=True, help="Output directory") - parser.add_argument("--neg_per_positive", type=int, default=2, help="Negatives per positive (half same-TF, half same-DNA)") + parser.add_argument( + "--neg_per_positive", + type=int, + default=2, + help="Negatives per positive (half same-TF, half same-DNA)", + ) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() @@ -126,12 +155,24 @@ def main(): print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}") # Split embeddings into per-item files - print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}") - dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna") - print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})") - print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}") - tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf") - print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})") + print( + f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}" + ) + dna_map = split_embeddings( + args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna" + ) + print( + f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})" + ) + print( + f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}" + ) + tf_map = split_embeddings( + args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf" + ) + print( + f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})" + ) # Build gene-symbol normalized map tf_symbol_map = build_tf_symbol_map(tf_map) @@ -145,15 +186,28 @@ def main(): print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}") print(f"[i] Intersection count: {len(intersect_tf)}") if len(intersect_tf) == 0: - print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr) - print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr) - print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr) + print( + "[ERROR] No overlap between normalized TF_id and TF embedding symbols.", + file=sys.stderr, + ) + print( + "Sample normalized TFs from final.csv:", + sorted(list(norm_tf_in_final))[:30], + file=sys.stderr, + ) + print( + "Sample available TF symbols:", + sorted(list(available_tf_symbols))[:30], + file=sys.stderr, + ) sys.exit(1) dna_ids_final = set(df["dna_id"].unique()) available_dna_ids = set(dna_map.keys()) intersect_dna = dna_ids_final & available_dna_ids - print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}") + print( + f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. 
Intersection: {len(intersect_dna)}" + ) if len(intersect_dna) == 0: print("[ERROR] No overlap on DNA ids.", file=sys.stderr) sys.exit(1) @@ -174,7 +228,9 @@ def main(): print(f"[i] Constructed {len(positives)} positive pairs after TF symbol resolution") if len(positives) == 0: - print("[ERROR] No positive pairs could be constructed; aborting.", file=sys.stderr) + print( + "[ERROR] No positive pairs could be constructed; aborting.", file=sys.stderr + ) sys.exit(1) # Build negative samples @@ -200,13 +256,18 @@ def main(): # same DNA, different TF for _ in range(half): candidate_tf_symbol = random.choice(all_tf_symbols) - if candidate_tf_symbol == tf_key or (candidate_tf_symbol, dna_key) in positive_set: + if ( + candidate_tf_symbol == tf_key + or (candidate_tf_symbol, dna_key) in positive_set + ): continue # pick its first embedding candidate_tf_path = tf_symbol_map[candidate_tf_symbol][0] negatives.append((candidate_tf_path, dna_map[dnaid], 0)) - print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive={args.neg_per_positive})") + print( + f"[i] Sampled {len(negatives)} negatives (neg_per_positive={args.neg_per_positive})" + ) # Write pair list pair_list_path = out_dir / "pair_list.tsv" @@ -216,5 +277,6 @@ def main(): f.write(f"{binder_path}\t{glm_path}\t{label}\n") print(f"[i] Wrote {len(positives)+len(negatives)} examples to {pair_list_path}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/dpacman/classifier/model_tmp/make_peak_fasta.py b/dpacman/classifier/model_tmp/make_peak_fasta.py index 9c7000d1e7b8f2170b703d364085e4f8f899b0d2..5babb216bcc0fdf1d26aba6695e3c613b61a221b 100644 --- a/dpacman/classifier/model_tmp/make_peak_fasta.py +++ b/dpacman/classifier/model_tmp/make_peak_fasta.py @@ -1,7 +1,9 @@ import pandas as pd from pathlib import Path -df = pd.read_csv("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv", dtype=str) # adjust path if needed +df = pd.read_csv( + "/home/a03-akrishna/DPACMAN/data_files/processed/final.csv", dtype=str +) # adjust path if needed # get unique sequences uniq = df[["dna_sequence"]].drop_duplicates().reset_index(drop=True) # make headers: e.g., peak0, peak1, ... 
@@ -10,4 +12,4 @@ with open(out_fa, "w") as f: for i, seq in enumerate(uniq["dna_sequence"]): header = f">peak{i}" f.write(f"{header}\n{seq}\n") -print(f"Wrote {len(uniq)} unique binding sequences to {out_fa}") \ No newline at end of file +print(f"Wrote {len(uniq)} unique binding sequences to {out_fa}") diff --git a/dpacman/classifier/model_tmp/model.py b/dpacman/classifier/model_tmp/model.py index 7288c2d5df020ba6347e6068a024d3f70e8b497e..6da6c208729bb52a057efeadf1ae2c252268683a 100644 --- a/dpacman/classifier/model_tmp/model.py +++ b/dpacman/classifier/model_tmp/model.py @@ -1,6 +1,7 @@ import torch from torch import nn + class LocalCNN(nn.Module): def __init__(self, dim: int = 256, kernel_size: int = 3): super().__init__() @@ -13,8 +14,9 @@ class LocalCNN(nn.Module): # x: (batch, L, dim) out = self.conv(x.transpose(1, 2)) # → (batch, dim, L) out = self.act(out) - out = out.transpose(1, 2) # → (batch, L, dim) - return self.ln(out + x) # residual + out = out.transpose(1, 2) # → (batch, L, dim) + return self.ln(out + x) # residual + class CrossModalBlock(nn.Module): def __init__(self, dim: int = 256, heads: int = 8): @@ -25,15 +27,21 @@ class CrossModalBlock(nn.Module): self.ln_b1 = nn.LayerNorm(dim) self.ln_g1 = nn.LayerNorm(dim) - self.ffn_b = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)) - self.ffn_g = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)) + self.ffn_b = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) + ) + self.ffn_g = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) + ) self.ln_b2 = nn.LayerNorm(dim) self.ln_g2 = nn.LayerNorm(dim) # cross attention (binder queries, glm keys/values) self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True) self.ln_c1 = nn.LayerNorm(dim) - self.ffn_c = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)) + self.ffn_c = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) + ) self.ln_c2 = nn.LayerNorm(dim) def forward(self, binder: torch.Tensor, glm: torch.Tensor): @@ -63,41 +71,41 @@ class CrossModalBlock(nn.Module): c = self.ln_c2(c + c_ff) return c # (batch, Lb, dim) + class BindPredictor(nn.Module): - def __init__(self, - input_dim: int = 256, - hidden_dim: int = 256, - heads: int = 8, - num_layers: int = 4, - use_local_cnn_on_glm: bool = True): + def __init__( + self, + input_dim: int = 256, + hidden_dim: int = 256, + heads: int = 8, + num_layers: int = 4, + use_local_cnn_on_glm: bool = True, + ): super().__init__() self.proj_binder = nn.Linear(input_dim, hidden_dim) self.proj_glm = nn.Linear(input_dim, hidden_dim) self.use_local_cnn = use_local_cnn_on_glm self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity() - self.layers = nn.ModuleList([ - CrossModalBlock(hidden_dim, heads) for _ in range(num_layers) - ]) + self.layers = nn.ModuleList( + [CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)] + ) self.ln_out = nn.LayerNorm(hidden_dim) - self.head = nn.Sequential( - nn.Linear(hidden_dim, 1), - nn.Sigmoid() - ) + self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) def forward(self, binder_emb, glm_emb): """ binder_emb, glm_emb: (batch, L, input_dim) """ - b = self.proj_binder(binder_emb) # (B, Lb, hidden_dim) - g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim) + b = self.proj_binder(binder_emb) # (B, Lb, hidden_dim) + g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim) if self.use_local_cnn: - g = self.local_cnn(g) # local 
context injected + g = self.local_cnn(g) # local context injected for layer in self.layers: - b = layer(b, g) # update binder with cross-modal info + b = layer(b, g) # update binder with cross-modal info - pooled = b.mean(dim=1) # (B, hidden_dim) + pooled = b.mean(dim=1) # (B, hidden_dim) out = self.ln_out(pooled) - return self.head(out).squeeze(-1) # (B,) + return self.head(out).squeeze(-1) # (B,) diff --git a/dpacman/classifier/model_tmp/prep_splits.py b/dpacman/classifier/model_tmp/prep_splits.py index 49bed7765d21faea29a77e46fc8c2d5bc8b46c17..f67681d4c2981550f2aa12a4aed6e34ea7080ae8 100644 --- a/dpacman/classifier/model_tmp/prep_splits.py +++ b/dpacman/classifier/model_tmp/prep_splits.py @@ -6,6 +6,7 @@ from sklearn.decomposition import TruncatedSVD from sklearn.model_selection import train_test_split from collections import Counter + def parse_pair_list(pair_list_path): binder_paths, glm_paths, labels = [], [], [] with open(pair_list_path) as f: @@ -14,7 +15,10 @@ def parse_pair_list(pair_list_path): continue parts = line.strip().split() if len(parts) != 3: - print(f"[WARN] skipping malformed line {lineno}: {line.strip()}", file=sys.stderr) + print( + f"[WARN] skipping malformed line {lineno}: {line.strip()}", + file=sys.stderr, + ) continue b, g, l = parts try: @@ -27,12 +31,16 @@ def parse_pair_list(pair_list_path): labels.append(lab) return binder_paths, glm_paths, labels + def build_tf_compressed_cache(binder_paths, target_dim=256): """ Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array. """ unique_paths = sorted(set(binder_paths)) - print(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True) + print( + f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", + flush=True, + ) # Load all embeddings to determine dimensionality samples = [] for p in unique_paths: @@ -41,7 +49,7 @@ def build_tf_compressed_cache(binder_paths, target_dim=256): # Determine if reduction needed: assume all have same embedding width first = samples[0] orig_dim = first.shape[1] if first.ndim == 2 else 1 - reduction_needed = (orig_dim != target_dim) + reduction_needed = orig_dim != target_dim tf_cache = {} if reduction_needed: @@ -57,7 +65,10 @@ def build_tf_compressed_cache(binder_paths, target_dim=256): else: pooled.append(arr) # degenerate pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim) - print(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True) + print( + f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", + flush=True, + ) svd = TruncatedSVD(n_components=target_dim, random_state=42) reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim) @@ -75,27 +86,36 @@ def build_tf_compressed_cache(binder_paths, target_dim=256): print("[i] Completed compression of TF embeddings.", flush=True) else: # already correct dim: just cache originals - print(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True) + print( + f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", + flush=True, + ) for i, p in enumerate(unique_paths): arr = samples[i] tf_cache[p] = arr return tf_cache + def main(): - #df = pd.read_csv("../data_files/processed/fimo/ananya_aug4_2025_final.csv") - - binder_paths, glm_paths, labels = parse_pair_list("../data_files/processed/fimo/ananya_aug4_2025_pair_list.tsv") + # df = 
pd.read_csv("../data_files/processed/fimo/ananya_aug4_2025_final.csv") + + binder_paths, glm_paths, labels = parse_pair_list( + "../data_files/processed/fimo/ananya_aug4_2025_pair_list.tsv" + ) if len(labels) == 0: print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr) sys.exit(1) label_counts = Counter(labels) - print(f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", flush=True) + print( + f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", + flush=True, + ) # build compressed TF cache (reduces to 256 if needed) - #tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256) - + # tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256) + # Combine all data into one structure for easy splitting data = list(zip(binder_paths, glm_paths, labels)) @@ -109,14 +129,15 @@ def main(): def unpack(data): binders, glms, labels = zip(*data) return list(binders), list(glms), list(labels) - def save_split(binder_paths, glm_paths, labels, out_path): - df = pd.DataFrame({ - "binder_path": binder_paths, - "glm_path": glm_paths, - "label": labels, - }) + df = pd.DataFrame( + { + "binder_path": binder_paths, + "glm_path": glm_paths, + "label": labels, + } + ) df.to_csv(out_path, index=False) # Unpack data for saving @@ -125,9 +146,12 @@ def main(): test_binders, test_glms, test_labels = unpack(test_data) # Save each split - save_split(train_binders, train_glms, train_labels, "../data_files/splits/train.csv") + save_split( + train_binders, train_glms, train_labels, "../data_files/splits/train.csv" + ) save_split(val_binders, val_glms, val_labels, "../data_files/splits/val.csv") save_split(test_binders, test_glms, test_labels, "../data_files/splits/test.csv") - -if __name__=="__main__": - main() \ No newline at end of file + + +if __name__ == "__main__": + main() diff --git a/dpacman/classifier/model_tmp/train.py b/dpacman/classifier/model_tmp/train.py index 11ee6cadd2aef08aa1fd54dc98e994bca7420711..05581aae5467bbed0ec3200ff38c18fb7741897c 100644 --- a/dpacman/classifier/model_tmp/train.py +++ b/dpacman/classifier/model_tmp/train.py @@ -12,12 +12,16 @@ import sys from dpacman.utils.models import set_seed + def build_tf_compressed_cache(binder_paths, target_dim=256): """ Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array. 
""" unique_paths = sorted(set(binder_paths)) - print(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True) + print( + f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", + flush=True, + ) # Load all embeddings to determine dimensionality samples = [] for p in unique_paths: @@ -26,7 +30,7 @@ def build_tf_compressed_cache(binder_paths, target_dim=256): # Determine if reduction needed: assume all have same embedding width first = samples[0] orig_dim = first.shape[1] if first.ndim == 2 else 1 - reduction_needed = (orig_dim != target_dim) + reduction_needed = orig_dim != target_dim tf_cache = {} if reduction_needed: @@ -42,7 +46,10 @@ def build_tf_compressed_cache(binder_paths, target_dim=256): else: pooled.append(arr) # degenerate pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim) - print(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True) + print( + f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", + flush=True, + ) svd = TruncatedSVD(n_components=target_dim, random_state=42) reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim) @@ -60,12 +67,16 @@ def build_tf_compressed_cache(binder_paths, target_dim=256): print("[i] Completed compression of TF embeddings.", flush=True) else: # already correct dim: just cache originals - print(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True) + print( + f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", + flush=True, + ) for i, p in enumerate(unique_paths): arr = samples[i] tf_cache[p] = arr return tf_cache + def evaluate(model, dl, device): model.eval() all_labels = [] @@ -92,13 +103,15 @@ def evaluate(model, dl, device): ap = 0.0 return auc, ap + def unpack(data): binders, glms, labels = zip(*data) return list(binders), list(glms), list(labels) + # ---- main ------------------------------------------------------------ def main(cfg): - # Set seed for reproducibility + # Set seed for reproducibility set_seed(cfg.seed) parser.add_argument("--out_dir", type=str, required=True) @@ -109,7 +122,7 @@ def main(cfg): parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() - # + # print("DEBUG: starting training script with in-line TF compression", flush=True) device = torch.device(args.device if torch.cuda.is_available() else "cpu") binder_paths, glm_paths, labels = parse_pair_list(cfg.pair_list) @@ -119,7 +132,10 @@ def main(cfg): sys.exit(1) label_counts = Counter(labels) - print(f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", flush=True) + print( + f"[i] Total examples parsed: {len(labels)}. 
Label distribution: {label_counts}", + flush=True, + ) # build compressed TF cache (reduces to 256 if needed) tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256) @@ -130,16 +146,30 @@ def main(cfg): val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache) test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache) - print(f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}", flush=True) + print( + f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}", + flush=True, + ) if len(train_ds) == 0 or len(val_ds) == 0: - print("[ERROR] Train or validation split is empty; cannot proceed.", file=sys.stderr) + print( + "[ERROR] Train or validation split is empty; cannot proceed.", + file=sys.stderr, + ) sys.exit(1) - train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn) - val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) - test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) - - model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True) + train_dl = DataLoader( + train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn + ) + val_dl = DataLoader( + val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn + ) + test_dl = DataLoader( + test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn + ) + + model = BindPredictor( + input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True + ) model = model.to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3) loss_fn = nn.BCELoss() @@ -164,17 +194,24 @@ def main(cfg): running_loss += loss.item() * b.size(0) train_loss = running_loss / len(train_ds) val_auc, val_ap = evaluate(model, val_dl, device) - print(f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}", flush=True) + print( + f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}", + flush=True, + ) if val_auc > best_val: best_val = val_auc torch.save(model.state_dict(), os_out / "best_model.pt") - print(f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}", flush=True) + print( + f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}", + flush=True, + ) torch.save(model.state_dict(), os_out / "last_model.pt") test_auc, test_ap = evaluate(model, test_dl, device) print(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True) print(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True) + if __name__ == "__main__": main() diff --git a/dpacman/classifier/old_train.py b/dpacman/classifier/old_train.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb745d5f707ba85b9916e45b75977d4262281b8 --- /dev/null +++ b/dpacman/classifier/old_train.py @@ -0,0 +1,486 @@ +import argparse, random, sys +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from torch import nn +from torch.utils.data import Dataset, DataLoader, Sampler + +# from sklearn.random_projection import GaussianRandomProjection # OLD (kept): projection was removed earlier +import matplotlib.pyplot as plt + +import torch.amp as amp +from torch.nn import functional as F +from model import BindPredictor + + +# ─────────────── utilities ──────────────────────────────────────────────── +def 
parse_pair_list(path): + binders, glms = [], [] + with open(path) as f: + for ln, line in enumerate(f, 1): + parts = line.strip().split() + if len(parts) < 2: + continue + b, g = parts[0], parts[1] + binders.append(b) + glms.append(g) + return binders, glms + + +class ListBatchSampler(Sampler): + def __init__(self, batches): + self.batches = batches + + def __iter__(self): + return iter(self.batches) + + def __len__(self): + return len(self.batches) + + +def make_buckets(idxs, glm_paths, batch_size, n_buckets=10, seed=42): + rng = random.Random(seed) + lengths = [(i, np.load(glm_paths[i]).shape[0]) for i in idxs] + lengths.sort(key=lambda x: x[1]) + size = max(1, int(np.ceil(len(lengths) / n_buckets))) + buckets = [lengths[i : i + size] for i in range(0, len(lengths), size)] + batches = [] + for bucket in buckets: + ids = [i for i, _ in bucket] + rng.shuffle(ids) + for i in range(0, len(ids), batch_size): + batches.append(ids[i : i + batch_size]) + rng.shuffle(batches) + return batches + + +def dna_key_from_path(path: str) -> str: + """.../dna_peak42.npy -> 'peak42'""" + stem = Path(path).stem + if "_" in stem: + _, rest = stem.split("_", 1) + else: + rest = stem + return rest + + +def build_tf_cache(tf_paths, target_dim=256): + """ + Load raw TF embeddings without projecting; compression is learnable in the model. + """ + unique = sorted(set(tf_paths)) + print( + f"[i] (Learnable) Preparing {len(unique)} TF files; target {target_dim}d inside the model", + flush=True, + ) + + pools, raw = [], [] + for p in unique: + arr = np.load(p) # (L, D) or (D,) + raw.append(arr) + pools.append(arr.mean(axis=0) if arr.ndim == 2 else arr) + M = np.stack(pools, 0) + orig_dim = M.shape[1] + print(f"[i] Pooled shape → {M.shape} (orig_dim={orig_dim})", flush=True) + + cache = {} + for i, p in enumerate(unique): + arr = raw[i] + # OLD: projection here (removed) + cache[p] = arr + print("[i] TF cache ready (raw); compression will be learned.", flush=True) + return cache + + +# ─────────────── Dataset & Collation ───────────────────────────────────── +class PairDataset(Dataset): + def __init__(self, tf_paths, dna_paths, final_df, tf_cache): + self.tf_paths, self.dna_paths = tf_paths, dna_paths + self.tf_cache = tf_cache + self.targets = {} + for _, row in final_df.iterrows(): + dna_id = row["dna_id"] + vec = np.array( + list(map(float, row["score_sig_r2"].split(","))), dtype=np.float32 + ) + self.targets[dna_id] = vec + + def __len__(self): + return len(self.tf_paths) + + def __getitem__(self, i): + b = self.tf_cache[self.tf_paths[i]] # (L_b, D_b) or (D_b,) + if b.ndim == 1: + b = b[None, :] + g = np.load(self.dna_paths[i]) # (L_g, 256) or (256,) + if g.ndim == 1: + g = g[None, :] + + stem = Path(self.dna_paths[i]).stem + dna_id = stem.replace("dna_", "") + t = self.targets.get(dna_id, np.zeros(g.shape[0], dtype=np.float32)) + + return ( + torch.from_numpy(b).float(), + torch.from_numpy(g).float(), + torch.from_numpy(t).float(), + ) + + +def collate_fn(batch): + Bs = [b.shape[0] for b, _, _ in batch] + Gs = [g.shape[0] for _, g, _ in batch] + maxB, maxG = max(Bs), max(Gs) + + def pad_seq(x, L): + if x.shape[0] < L: + pad = torch.zeros( + (L - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device + ) + return torch.cat([x, pad], dim=0) + return x + + def pad_t(y, L): + if y.shape[0] < L: + pad = torch.zeros((L - y.shape[0],), dtype=y.dtype, device=y.device) + return torch.cat([y, pad], dim=0) + return y + + b_stack = torch.stack([pad_seq(b, maxB) for b, _, _ in batch]) + g_stack = torch.stack([pad_seq(g, 
maxG) for _, g, _ in batch]) + t_stack = torch.stack([pad_t(t, maxG) for *_, t in batch]) + return b_stack, g_stack, t_stack + + +# ─────────────── losses, metrics ───────────────────────────────────────── +def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8): + probs = torch.sigmoid(logits) + labels = (targets >= peak_thresh).float() + non_peak_mask = (labels == 0).float() + peak_mask = (labels == 1).float() + + bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction="none") + bce_non = bce_all * non_peak_mask + bce_non = bce_non.sum() / (non_peak_mask.sum() + eps) + + mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction="sum") / ( + peak_mask.sum() + eps + ) + mse_global = F.mse_loss(probs, targets, reduction="mean") + + t_dist = targets + eps + p_dist = probs + eps + t_dist = t_dist / t_dist.sum(dim=1, keepdim=True) + p_dist = p_dist / p_dist.sum(dim=1, keepdim=True) + kl = ( + (t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log())) + .sum(dim=1) + .mean() + ) + + return bce_non, kl, mse_global, probs + + +def accuracy_percentage(logits, targets, peak_thresh=0.5): + probs = torch.sigmoid(logits) + preds_bin = (probs >= 0.5).float() + labels = (targets >= peak_thresh).float() + correct = (preds_bin == labels).float().sum() + total = torch.numel(labels) + return (correct / max(1, total)).item() * 100.0 + + +def evaluate(model, dl, device, alpha, beta, gamma, peak_thresh, eps=1e-8): + model.eval() + tot_loss, tot_acc = 0.0, 0.0 + n_batches = 0 + with torch.no_grad(): + for b, g, t in dl: + b, g, t = b.to(device), g.to(device), t.to(device) + logits = model(b, g) + bce_non, kl, mse_global, _ = combined_loss_components( + logits, t, peak_thresh=peak_thresh, eps=eps + ) + loss = alpha * bce_non + beta * kl + gamma * mse_global + acc = accuracy_percentage(logits, t, peak_thresh=peak_thresh) + tot_loss += loss.item() + tot_acc += acc + n_batches += 1 + if n_batches == 0: + return float("nan"), float("nan") + return tot_loss / n_batches, tot_acc / n_batches + + +# ─────────────── cluster-aware splitting ────────────────────────────────── +def assign_clusters_to_splits( + cluster_to_indices, val_frac=0.10, test_frac=0.10, seed=42 +): + """ + cluster_to_indices: dict[cluster_id] -> list of example indices (from pair_list) in that cluster + We greedily pack whole clusters into val/test until hitting targets (#examples), rest to train. 
+ """ + rng = random.Random(seed) + clusters = list(cluster_to_indices.items()) + rng.shuffle(clusters) + + total = sum(len(ixs) for _, ixs in clusters) + target_val = int(round(total * val_frac)) + target_test = int(round(total * test_frac)) + cur_val = cur_test = 0 + + tr_ix, va_ix, te_ix = [], [], [] + for cid, ixs in clusters: + c = len(ixs) + if cur_val + c <= target_val: + va_ix.extend(ixs) + cur_val += c + elif cur_test + c <= target_test: + te_ix.extend(ixs) + cur_test += c + else: + tr_ix.extend(ixs) + return tr_ix, va_ix, te_ix + + +# ─────────────── train & main ──────────────────────────────────────────── +def main(): + p = argparse.ArgumentParser() + p.add_argument("--pair_list", required=True) + p.add_argument("--final_csv", required=True) + p.add_argument("--out_dir", required=True) + p.add_argument("--epochs", type=int, default=10) + p.add_argument("--batch_size", type=int, default=16) + p.add_argument("--accum_steps", type=int, default=4) + p.add_argument("--lr", type=float, default=1e-4) + p.add_argument("--device", default="cuda") + p.add_argument("--seed", type=int, default=42) + p.add_argument("--alpha", type=float, default=1) + p.add_argument("--beta", type=float, default=0) + p.add_argument("--gamma", type=float, default=1) + p.add_argument("--peak_thresh", type=float, default=0.5) + # NEW: fractions for cluster-aware split (used only if cluster_id present) + p.add_argument("--val_frac", type=float, default=0.10) + p.add_argument("--test_frac", type=float, default=0.10) + args = p.parse_args() + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + + # 1) load pair list & final.csv (now may include cluster_id) + tf_paths, dna_paths = parse_pair_list(args.pair_list) + final_df = pd.read_csv(args.final_csv, dtype=str) + print(f"[i] Loaded {len(tf_paths)} pairs", flush=True) + + tf_cache = build_tf_cache(tf_paths, target_dim=256) + + # detect binder/DNA dims + sample_tf = tf_cache[tf_paths[0]] + binder_input_dim = sample_tf.shape[1] if sample_tf.ndim == 2 else sample_tf.shape[0] + glm_input_dim = 256 + + # 2) cluster-aware split if possible + use_cluster_split = "cluster_id" in final_df.columns + if use_cluster_split: + print( + "[i] Cluster column detected in final_csv; performing cluster-aware split.", + flush=True, + ) + # build dna_id -> cluster_id map + cid_map = ( + final_df[["dna_id", "cluster_id"]] + .dropna() + .drop_duplicates() + .set_index("dna_id")["cluster_id"] + .to_dict() + ) + + # map each example (by index) to its dna_id and cluster + example_dna_ids = [dna_key_from_path(p) for p in dna_paths] + example_clusters = [] + missing = 0 + for did in example_dna_ids: + if did in cid_map: + example_clusters.append(cid_map[did]) + else: + # fallback: treat singleton cluster + example_clusters.append(f"singleton::{did}") + missing += 1 + if missing: + print( + f"[WARN] {missing} dna_ids from pair_list not found in cluster map; treating as singleton clusters.", + flush=True, + ) + + # build cluster -> indices + cluster_to_indices = {} + for i, cid in enumerate(example_clusters): + cluster_to_indices.setdefault(cid, []).append(i) + + tr_idx, va_idx, te_idx = assign_clusters_to_splits( + cluster_to_indices, + val_frac=args.val_frac, + test_frac=args.test_frac, + seed=args.seed, + ) + print( + f"[i] Cluster split sizes (examples): train={len(tr_idx)} val={len(va_idx)} test={len(te_idx)}", + flush=True, + ) + + # helper to subset paths + def 
subset_by_indices(ixs): + return [tf_paths[i] for i in ixs], [dna_paths[i] for i in ixs] + + tr_t, tr_d = subset_by_indices(tr_idx) + va_t, va_d = subset_by_indices(va_idx) + te_t, te_d = subset_by_indices(te_idx) + + else: + print( + "[i] No cluster_id in final_csv; using random 80/10/10 split (OLD behavior).", + flush=True, + ) + # OLD random split (kept, now under else) + N = len(tf_paths) + idxs = list(range(N)) + random.shuffle(idxs) + n_tr = int(0.8 * N) + n_va = int(0.1 * N) + tr, va, te = idxs[:n_tr], idxs[n_tr : n_tr + n_va], idxs[n_tr + n_va :] + + def subset(idxs_): + return [tf_paths[i] for i in idxs_], [dna_paths[i] for i in idxs_] + + tr_t, tr_d = subset(tr) + va_t, va_d = subset(va) + te_t, te_d = subset(te) + + # 3) bucketed samplers (unchanged, but now use the cluster-aware subsets when available) + tr_bs = make_buckets( + list(range(len(tr_t))), tr_d, args.batch_size, n_buckets=10, seed=args.seed + ) + va_bs = make_buckets( + list(range(len(va_t))), va_d, args.batch_size, n_buckets=5, seed=args.seed + 1 + ) + te_bs = make_buckets( + list(range(len(te_t))), te_d, args.batch_size, n_buckets=5, seed=args.seed + 2 + ) + + tr_dl = DataLoader( + PairDataset(tr_t, tr_d, final_df, tf_cache), + batch_sampler=ListBatchSampler(tr_bs), + collate_fn=collate_fn, + ) + va_dl = DataLoader( + PairDataset(va_t, va_d, final_df, tf_cache), + batch_sampler=ListBatchSampler(va_bs), + collate_fn=collate_fn, + ) + te_dl = DataLoader( + PairDataset(te_t, te_d, final_df, tf_cache), + batch_sampler=ListBatchSampler(te_bs), + collate_fn=collate_fn, + ) + + # 4) model, optimizer, scaler + model = BindPredictor( + binder_input_dim=binder_input_dim, + glm_input_dim=glm_input_dim, + compressed_dim=256, + hidden_dim=256, + heads=8, + num_layers=4, + use_local_cnn_on_glm=True, + ).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) + scaler = amp.GradScaler("cuda") + + history, best_val = {"train": [], "val": []}, float("inf") + od = Path(args.out_dir) + od.mkdir(exist_ok=True, parents=True) + + for ep in range(1, args.epochs + 1): + print(f"┌─[Epoch {ep}]────────────────────────", flush=True) + model.train() + optimizer.zero_grad() + acc_loss_sum, acc_acc_sum, n_train_batches = 0.0, 0.0, 0 + + for i, (b, g, t) in enumerate(tr_dl): + b, g, t = b.to(device), g.to(device), t.to(device) + with amp.autocast("cuda"): + logits = model(b, g) + bce_non, kl, mse_global, probs = combined_loss_components( + logits, t, peak_thresh=args.peak_thresh + ) + loss = args.alpha * bce_non + args.beta * kl + args.gamma * mse_global + loss = loss / args.accum_steps + + scaler.scale(loss).backward() + + if (i + 1) % args.accum_steps == 0: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + with torch.no_grad(): + acc_loss_sum += loss.item() * args.accum_steps + acc_acc_sum += accuracy_percentage( + logits, t, peak_thresh=args.peak_thresh + ) + n_train_batches += 1 + + del b, g, t, logits, probs, loss, bce_non, kl, mse_global + torch.cuda.empty_cache() + + # finalize if leftovers + if n_train_batches % args.accum_steps != 0: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + train_loss = acc_loss_sum / max(1, n_train_batches) + train_acc = acc_acc_sum / max(1, n_train_batches) + + val_loss, val_acc = evaluate( + model, + va_dl, + device, + alpha=args.alpha, + beta=args.beta, + gamma=args.gamma, + peak_thresh=args.peak_thresh, + ) + print( + f"[Epoch {ep}] train_loss={train_loss:.4f} train_acc={train_acc:.2f}% " + f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}%", + 
flush=True, + ) + + history["train"].append(train_loss) + history["val"].append(val_loss) + if val_loss < best_val: + best_val = val_loss + torch.save(model.state_dict(), od / "best_model.pt") + print( + f" Saved new best_model.pt (val_loss={val_loss:.4f}, val_acc={val_acc:.2f}%)", + flush=True, + ) + + torch.save(model.state_dict(), od / "last_model.pt") + + fig, ax = plt.subplots() + ax.plot(history["train"], label="train") + ax.plot(history["val"], label="val") + ax.set_xlabel("epoch") + ax.set_ylabel("combined loss") + ax.legend() + fig.savefig(od / "loss_curve.png") + print(f"✅ Done → outputs in {od}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/dpacman/classifier/model/model.py b/dpacman/classifier/torch_model.py similarity index 71% rename from dpacman/classifier/model/model.py rename to dpacman/classifier/torch_model.py index d573624a6629aeb4ca1102e17ad8a716230f9825..8c19fee12fbecf27ac6da4d7187f7bf83cdae34f 100644 --- a/dpacman/classifier/model/model.py +++ b/dpacman/classifier/torch_model.py @@ -1,6 +1,7 @@ import torch from torch import nn + class LocalCNN(nn.Module): def __init__(self, dim: int = 256, kernel_size: int = 3): super().__init__() @@ -13,8 +14,9 @@ class LocalCNN(nn.Module): # x: (batch, L, dim) out = self.conv(x.transpose(1, 2)) # → (batch, dim, L) out = self.act(out) - out = out.transpose(1, 2) # → (batch, L, dim) - return self.ln(out + x) # residual + out = out.transpose(1, 2) # → (batch, L, dim) + return self.ln(out + x) # residual + class CrossModalBlock(nn.Module): def __init__(self, dim: int = 256, heads: int = 8): @@ -25,15 +27,21 @@ class CrossModalBlock(nn.Module): self.ln_b1 = nn.LayerNorm(dim) self.ln_g1 = nn.LayerNorm(dim) - self.ffn_b = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)) - self.ffn_g = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)) + self.ffn_b = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) + ) + self.ffn_g = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) + ) self.ln_b2 = nn.LayerNorm(dim) self.ln_g2 = nn.LayerNorm(dim) # cross attention (binder queries, glm keys/values) self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True) self.ln_c1 = nn.LayerNorm(dim) - self.ffn_c = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)) + self.ffn_c = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) + ) self.ln_c2 = nn.LayerNorm(dim) def forward(self, binder: torch.Tensor, glm: torch.Tensor): @@ -63,11 +71,13 @@ class CrossModalBlock(nn.Module): c = self.ln_c2(c + c_ff) return c # (batch, Lb, dim) + class DimCompressor(nn.Module): """ Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256). If in_dim == out_dim, behaves as identity. 
""" + def __init__(self, in_dim: int, out_dim: int = 256): super().__init__() if in_dim == out_dim: @@ -85,16 +95,19 @@ class DimCompressor(nn.Module): # x: (B, L, in_dim) return self.net(x) + class BindPredictor(nn.Module): - def __init__(self, - # input_dim: int = 256, # OLD: single input dim - binder_input_dim: int = 1280, # NEW: TF (binder) original dim (e.g., 1280) - glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256) - compressed_dim: int = 256, # NEW: learnable compressed dim - hidden_dim: int = 256, - heads: int = 8, - num_layers: int = 4, - use_local_cnn_on_glm: bool = True): + def __init__( + self, + # input_dim: int = 256, # OLD: single input dim + binder_input_dim: int = 1280, # NEW: TF (binder) original dim (e.g., 1280) + glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256) + compressed_dim: int = 256, # NEW: learnable compressed dim + hidden_dim: int = 256, + heads: int = 8, + num_layers: int = 4, + use_local_cnn_on_glm: bool = True, + ): super().__init__() # OLD: # self.proj_binder = nn.Linear(input_dim, hidden_dim) @@ -102,7 +115,7 @@ class BindPredictor(nn.Module): # NEW: learnable compressor for binder → 256, then project to hidden self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim) - self.proj_binder = nn.Linear(compressed_dim, hidden_dim) + self.proj_binder = nn.Linear(compressed_dim, hidden_dim) # GLM side stays 256 → hidden self.proj_glm = nn.Linear(glm_input_dim, hidden_dim) @@ -110,13 +123,13 @@ class BindPredictor(nn.Module): self.use_local_cnn = use_local_cnn_on_glm self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity() - self.layers = nn.ModuleList([ - CrossModalBlock(hidden_dim, heads) for _ in range(num_layers) - ]) + self.layers = nn.ModuleList( + [CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)] + ) self.ln_out = nn.LayerNorm(hidden_dim) # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) # OLD: returned probabilities - self.head = nn.Linear(hidden_dim, 1) # NEW: return logits (safe for AMP) + self.head = nn.Linear(hidden_dim, 1) # NEW: return logits (safe for AMP) def forward(self, binder_emb, glm_emb): """ @@ -125,18 +138,20 @@ class BindPredictor(nn.Module): Returns per-nucleotide logits for the GLM sequence: (B, Lg) """ # Binder: learnable compression → 256 → hidden - b = self.binder_compress(binder_emb) # (B, Lb, 256) - b = self.proj_binder(b) # (B, Lb, hidden_dim) + b = self.binder_compress(binder_emb) # (B, Lb, 256) + b = self.proj_binder(b) # (B, Lb, hidden_dim) # GLM: project → hidden, add local CNN context - g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim) + g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim) if self.use_local_cnn: g = self.local_cnn(g) # Cross-modal blocks: update binder states using GLM for layer in self.layers: - b = layer(b, g) # (B, Lb, hidden_dim) + b = layer(b, g) # (B, Lb, hidden_dim) # Predict per-nucleotide logits on the GLM tokens: # return self.head(g).squeeze(-1) # OLD: probabilities (with Sigmoid in head) - return self.head(g).squeeze(-1) # NEW: logits (apply sigmoid only in loss/metrics) + return self.head(g).squeeze( + -1 + ) # NEW: logits (apply sigmoid only in loss/metrics) diff --git a/dpacman/classifier/train.py b/dpacman/classifier/train.py new file mode 100644 index 0000000000000000000000000000000000000000..65e735a730679ce5d9e3e4898656d7e14672c5e4 --- /dev/null +++ b/dpacman/classifier/train.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +import argparse +import numpy as np +import torch +from torch import 
nn +from model import BindPredictor +from pathlib import Path +from collections import Counter +from sklearn.metrics import roc_auc_score, average_precision_score +from sklearn.decomposition import TruncatedSVD +import sys + +from dpacman.utils.models import set_seed + + +def build_tf_compressed_cache(binder_paths, target_dim=256): + """ + Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array. + """ + unique_paths = sorted(set(binder_paths)) + print( + f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", + flush=True, + ) + # Load all embeddings to determine dimensionality + samples = [] + for p in unique_paths: + arr = np.load(p) + samples.append(arr) + # Determine if reduction needed: assume all have same embedding width + first = samples[0] + orig_dim = first.shape[1] if first.ndim == 2 else 1 + reduction_needed = orig_dim != target_dim + tf_cache = {} + + if reduction_needed: + # Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack. + # We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features: + # Simplest: for variable lengths, reduce each embedding separately with a learned linear projection. + # Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length. + # To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors: + pooled = [] + for arr in samples: + if arr.ndim == 2: + pooled.append(arr.mean(axis=0)) # (orig_dim,) + else: + pooled.append(arr) # degenerate + pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim) + print( + f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", + flush=True, + ) + svd = TruncatedSVD(n_components=target_dim, random_state=42) + reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim) + + # For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T + # svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim) + proj_mat = svd.components_.T # (orig_dim, target_dim) + for i, p in enumerate(unique_paths): + arr = samples[i] # shape (L, orig_dim) + if arr.ndim == 1: + arr2 = arr @ proj_mat # (target_dim,) + else: + # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim) + arr2 = arr @ proj_mat + tf_cache[p] = arr2 # reduced per-token representation + print("[i] Completed compression of TF embeddings.", flush=True) + else: + # already correct dim: just cache originals + print( + f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", + flush=True, + ) + for i, p in enumerate(unique_paths): + arr = samples[i] + tf_cache[p] = arr + return tf_cache + + +def evaluate(model, dl, device): + model.eval() + all_labels = [] + all_preds = [] + with torch.no_grad(): + for b, g, y in dl: + b = b.to(device) + g = g.to(device) + y = y.to(device) + pred = model(b, g) + all_labels.append(y.cpu()) + all_preds.append(pred.cpu()) + if not all_labels: + return 0.0, 0.0 + y_true = torch.cat(all_labels).numpy() + y_score = torch.cat(all_preds).numpy() + try: + auc = roc_auc_score(y_true, y_score) + except Exception: + auc = 0.0 + try: + ap = average_precision_score(y_true, y_score) + except Exception: + ap = 0.0 + return auc, ap + + +def unpack(data): + binders, glms, labels = zip(*data) + return list(binders), list(glms), list(labels) + + +# 
---- main ------------------------------------------------------------ +def main(cfg): + """ + Main method, used to train the model. + """ + # Set seed for reproducibility + set_seed(cfg.seed) + + parser.add_argument("--out_dir", type=str, required=True) + parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + # + print("DEBUG: starting training script with in-line TF compression", flush=True) + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + binder_paths, glm_paths, labels = parse_pair_list(cfg.pair_list) + + if len(labels) == 0: + print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr) + sys.exit(1) + + label_counts = Counter(labels) + print( + f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", + flush=True, + ) + + # build compressed TF cache (reduces to 256 if needed) + tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256) + + # load training data + + train_ds = PairDataset(None, tf_compressed_cache=tf_compressed_cache) + val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache) + test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache) + + print( + f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}", + flush=True, + ) + if len(train_ds) == 0 or len(val_ds) == 0: + print( + "[ERROR] Train or validation split is empty; cannot proceed.", + file=sys.stderr, + ) + sys.exit(1) + + train_dl = DataLoader( + train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn + ) + val_dl = DataLoader( + val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn + ) + test_dl = DataLoader( + test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn + ) + + model = BindPredictor( + input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True + ) + model = model.to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3) + loss_fn = nn.BCELoss() + + best_val = -float("inf") + os_out = Path(args.out_dir) + os_out.mkdir(exist_ok=True, parents=True) + + for epoch in range(1, args.epochs + 1): + print(f"[Epoch {epoch}] starting...", flush=True) + model.train() + running_loss = 0.0 + for b, g, y in train_dl: + b = b.to(device) + g = g.to(device) + y = y.to(device) + pred = model(b, g) + loss = loss_fn(pred, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + running_loss += loss.item() * b.size(0) + train_loss = running_loss / len(train_ds) + val_auc, val_ap = evaluate(model, val_dl, device) + print( + f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}", + flush=True, + ) + + if val_auc > best_val: + best_val = val_auc + torch.save(model.state_dict(), os_out / "best_model.pt") + print( + f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}", + flush=True, + ) + + torch.save(model.state_dict(), os_out / "last_model.pt") + test_auc, test_ap = evaluate(model, test_dl, device) + print(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True) + print(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True) + + +if __name__ == "__main__": + main() diff 
--git a/dpacman/data_modules/pair.py b/dpacman/data_modules/pair.py index 8f990fe633d918298b76d15e385f637eab723417..d84a6d591b8be37098a00c50c80b744a987e5126 100644 --- a/dpacman/data_modules/pair.py +++ b/dpacman/data_modules/pair.py @@ -2,341 +2,570 @@ import argparse import numpy as np import torch -from torch import nn -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Dataset, DataLoader, Sampler from lightning import LightningDataModule -from model import BindPredictor from pathlib import Path -from collections import Counter -from sklearn.metrics import roc_auc_score, average_precision_score -from sklearn.decomposition import TruncatedSVD from multiprocessing import cpu_count -from functools import partial import random -import sys import pandas as pd +import shelve +from torch.nn.utils.rnn import pad_sequence +from typing import List, Iterable, Sequence +import sys +import rootutils +from dpacman.utils import pylogger + +root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +logger = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +class PreBatchedSampler(Sampler[List[int]]): + """ + Yields precomputed batches of indices, e.g. [[3,7,9], [0,1,2], ...]. + Useful when you've already formed batches by length. + """ + + def __init__( + self, + batches: Sequence[Sequence[int]], + shuffle_batch_order: bool = False, + generator=None, + ): + self.batches = [list(b) for b in batches] + self.shuffle_batch_order = shuffle_batch_order + self.generator = generator + + def __iter__(self) -> Iterable[List[int]]: + if self.shuffle_batch_order: + # local copy we can shuffle without touching the original + idxs = list(range(len(self.batches))) + g = self.generator if self.generator is not None else torch.Generator() + perm = torch.randperm(len(idxs), generator=g).tolist() + for i in perm: + yield self.batches[i] + else: + for b in self.batches: + yield b + + def __len__(self) -> int: + return len(self.batches) + + +def compute_tr_lengths_from_shelf( + tr_shelf_path: str, tr_sequences: list[str] +) -> list[int]: + """ + Opens the TR shelf once and returns length for each sequence. + 2D array -> length = shape[0]; 1D array (pooled) -> length = 1. + """ + lengths = [] + with shelve.open(tr_shelf_path, flag="r") as db: + for s in tr_sequences: + arr = np.asarray(db[str(s)]) + if arr.ndim == 1: + lengths.append(1) + else: + lengths.append(int(arr.shape[0])) + return lengths + + +def make_length_batches( + dataset_records: list[dict], + tr_shelf_path: str, + batch_size: int, + drop_last: bool = False, +) -> list[list[int]]: + """ + dataset_records: output of PairDataset._load_and_normalize(...), i.e. list of dicts with + keys: "dna_sequence", "tr_sequence", "scores", ... + Returns a list of batches, each a list of indices, sorted by (dna_len, tr_len). 
+ """ + # DNA length comes from label length + dna_lens = [len(r["scores"]) for r in dataset_records] + tr_seqs = [r["tr_sequence"] for r in dataset_records] -from dpacman.utils import RankedLogger + # TR length requires a quick shelf lookup (done once here) + tr_lens = compute_tr_lengths_from_shelf(tr_shelf_path, tr_seqs) -logger = RankedLogger(__name__, rank_zero_only=True) + # sort indices by (dna_len, tr_len) + idxs = list(range(len(dataset_records))) + idxs.sort(key=lambda i: (dna_lens[i], tr_lens[i])) + + # chunk into fixed-size batches + batches = [idxs[i : i + batch_size] for i in range(0, len(idxs), batch_size)] + if drop_last and len(batches) and len(batches[-1]) < batch_size: + batches.pop() + return batches # ---- dataset --------------------------------------------------------- class PairDataset(Dataset): - def __init__(self, tf_paths, dna_paths, final_df, tf_cache): + def __init__( + self, dataset: pd.DataFrame, norm_value: int = 1333, round_to: int = 4 + ): + """ + Args: + - dataset: a dataset with the needed information: ID, dna_sequence, tr_sequence, scores + - norm_value: max score, which we'll use to divide all the integer scores in "scores" + - round_to: how many decimal places for the numerical score values """ - tf_cache: dict mapping binder_path -> compressed (256-d) tensor/array + self.dataset = self._load_and_normalize(dataset, norm_value, round_to) + self.norm_value = ( + norm_value # what to divide everything in labels by to make it a float + ) + + def _load_and_normalize(self, dataset, norm_value: int, round_to: int): + """ + Labels come in looking like "0,0,0,100,100,133,133,100,100,0,0," + This method turns the labels from strings into floats out to 4 decimal places """ - self.tf_paths, self.dna_paths = tf_paths, dna_paths - self.tf_cache = tf_cache - self.targets = {} - for _, row in final_df.iterrows(): - dna_id = row["dna_id"] - vec = np.array(list(map(float, row["score_sig_r2"].split(","))), dtype=np.float32) - self.targets[dna_id] = vec + # split string into list of strings + dataset["scores"] = dataset["scores"].apply(lambda x: x.split(",")) + # turn list of strings into list of normalized, rounded floats + dataset["scores"] = dataset["scores"].apply( + lambda x: [round(int(y) / norm_value, round_to) for y in x] + ) + + # convert to records for ease of loading + dataset = dataset.to_dict(orient="records") + return dataset def __len__(self): - return len(self.tf_paths) - - def __getitem__(self, i): - b = self.tf_cache[self.tf_paths[i]] # (L_b, D_b) or (D_b,) - if b.ndim==1: b = b[None,:] - g = np.load(self.dna_paths[i]) # (L_g, 256) or (256,) - if g.ndim==1: g = g[None,:] - - stem = Path(self.dna_paths[i]).stem - dna_id = stem.replace("dna_","") - t = self.targets.get(dna_id, np.zeros(g.shape[0],dtype=np.float32)) - - return torch.from_numpy(b).float(), \ - torch.from_numpy(g).float(), \ - torch.from_numpy(t).float() - - + return len(self.dataset) + + def __getitem__(self, idx): + item = self.dataset[idx] + return {**(item if isinstance(item, dict) else {})} + + class PairDataModule(LightningDataModule): def __init__( self, - train_file: str = "data_files/splits/train.csv", - val_file: str = "data_files/splits/val.csv", - test_file: str = "data_files/splits/test.csv", - tokenizer_path="facebook/esm2_t33_650M_UR50D", + train_file: Path | str = "../data_files/splits/train.csv", + val_file: Path | str = "../data_files/splits/val.csv", + test_file: Path | str = "../data_files/splits/test.csv", + tr_shelf_path: ( + Path | str + ) = 
"../data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf", + dna_shelf_path: ( + Path | str + ) = "../data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf", batch_size: int = 1, num_workers=8, maximize_num_workers=False, debug_run: bool = False, pin_memory: bool = False, + shuffle_train_batch_order: bool = True, ): super().__init__() self.save_hyperparameters() self.debug_run = debug_run - + # Initialize the data files self.train_data_file = train_file self.val_data_file = val_file self.test_data_file = test_file - + # Initialize hyperparameters like batch size self.batch_size = batch_size - self.num_wokers = cpu_count() if maximize_num_workers else min(num_workers, cpu_count()) - - logger.info(f"num_workers={self.num_wokers}") + self.num_workers = ( + cpu_count() if maximize_num_workers else min(num_workers, cpu_count()) + ) + + # Set up ShelfCollator + self.collate = ShelfCollator( + tr_shelf_path=str(tr_shelf_path), + dna_shelf_path=str(dna_shelf_path), + tr_key="tr_sequence", + dna_key="dna_sequence", + dtype=torch.float32, + pad_value=0.0, + ) + self.drop_last = False # or True, your choice + self.shuffle_batch_order = shuffle_train_batch_order # False keep batches deterministic per epoch; set True if you want to shuffle batch order + + logger.info(f"num_workers={self.num_workers}") logger.info("Initialized BinderDecoyDataModule constants") - - def load_and_unpack(self, file_path, lim=None): + + def load_file(self, file_path, lim=None): """ - Load and unpack an input csv whose columns are binder_path,glm_path,label + Load and unpack an input csv whose columns are binder_path,glm_path,label """ df = pd.read_csv(file_path) if lim is not None: df = df[:lim].reset_index(drop=True) - - binder_paths = df["binder_path"].tolist() - glm_paths = df["glm_path"].tolist() - labels = df["label"].tolist() - - return binder_paths, glm_paths, labels - - def setup(self, stage): + + return df[["ID", "dna_sequence", "tr_sequence", "scores"]] + + def setup(self, stage: str | None = None): lim = 5 if self.debug_run else None - if stage=="train": - binder_paths, glm_paths, labels = self.load_file(self.train_data_file, lim=lim) - self.train_dataset = PairDataset(binder_paths, glm_paths, labels) - elif stage=="val": - binder_paths, glm_paths, labels = self.load_file(self.val_data_file, lim=lim) - self.val_dataset = PairDataset(binder_paths, glm_paths, labels) - elif stage=="test": - binder_paths, glm_paths, labels = self.load_file(self.test_data_file, lim=lim) - self.test_dataset = PairDataset(binder_paths, glm_paths, labels) - else: - raise RuntimeError(f"Stage {stage} is not defined. 
Must be train, val, or test.") - + + # FIT: build train & val (so val exists during training) + if stage in (None, "fit"): + if not hasattr(self, "train_dataset"): + train_df = self.load_file(self.train_data_file, lim=lim) + self.train_dataset = PairDataset(train_df) + self.train_batches = make_length_batches( + dataset_records=self.train_dataset.dataset, + tr_shelf_path=str(self.hparams.tr_shelf_path), + batch_size=self.batch_size, + drop_last=self.drop_last, + ) + self.train_batch_sampler = PreBatchedSampler( + self.train_batches, + shuffle_batch_order=self.shuffle_batch_order, + ) + + if not hasattr(self, "val_dataset"): + val_df = self.load_file(self.val_data_file, lim=lim) + self.val_dataset = PairDataset(val_df) + self.val_batches = make_length_batches( + dataset_records=self.val_dataset.dataset, + tr_shelf_path=str(self.hparams.tr_shelf_path), + batch_size=self.batch_size, + drop_last=False, + ) + self.val_batch_sampler = PreBatchedSampler( + self.val_batches, shuffle_batch_order=False + ) + + # VALIDATE called standalone: ensure val is built + if stage in (None, "validate"): + if not hasattr(self, "val_dataset"): + val_df = self.load_file(self.val_data_file, lim=lim) + self.val_dataset = PairDataset(val_df) + self.val_batches = make_length_batches( + dataset_records=self.val_dataset.dataset, + tr_shelf_path=str(self.hparams.tr_shelf_path), + batch_size=self.batch_size, + drop_last=False, + ) + self.val_batch_sampler = PreBatchedSampler( + self.val_batches, shuffle_batch_order=False + ) + + # TEST phase + if stage in (None, "test"): + if not hasattr(self, "test_dataset"): + test_df = self.load_file(self.test_data_file, lim=lim) + self.test_dataset = PairDataset(test_df) + self.test_batches = make_length_batches( + dataset_records=self.test_dataset.dataset, + tr_shelf_path=str(self.hparams.tr_shelf_path), + batch_size=self.batch_size, + drop_last=False, + ) + self.test_batch_sampler = PreBatchedSampler( + self.test_batches, shuffle_batch_order=False + ) + def train_dataloader(self): return DataLoader( self.train_dataset, - batch_size=self.batch_size, - collate_fn=collate_fn, - num_workers=self.num_wokers, + batch_sampler=self.train_batch_sampler, + collate_fn=self.collate, + num_workers=self.num_workers, + persistent_workers=(self.num_workers > 0), pin_memory=self.hparams.pin_memory, - shuffle=True, ) def val_dataloader(self): return DataLoader( self.val_dataset, - batch_size=self.batch_size, - collate_fn=collate_fn, - num_workers=self.num_wokers, + batch_sampler=self.val_batch_sampler, + collate_fn=self.collate, + num_workers=self.num_workers, + persistent_workers=(self.num_workers > 0), pin_memory=self.hparams.pin_memory, - shuffle=False, ) def test_dataloader(self): return DataLoader( self.test_dataset, - batch_size=self.batch_size, - collate_fn=collate_fn, - num_workers=self.num_wokers, + batch_sampler=self.test_batch_sampler, + collate_fn=self.collate, + num_workers=self.num_workers, + persistent_workers=(self.num_workers > 0), pin_memory=self.hparams.pin_memory, - shuffle=False, ) -def collate_fn(batch): - Bs = [b.shape[0] for b,_,_ in batch] - Gs = [g.shape[0] for _,g,_ in batch] +class ShelfCollator: + """ + Lazily opens TR (binder) and DNA shelves the first time each worker calls __call__. 
+ Expects each item to contain keys: + - "tr_sequence": str (key for TR shelf) + - "dna_sequence": str (key for DNA shelf) + - "scores": list[float] (per-base labels for DNA) + - optional "ID" + Returns a dict with: + - binder_emb: FloatTensor [B, Lb_max, Db] (padded) + - binder_mask: BoolTensor [B, Lb_max] + - glm_emb: FloatTensor [B, Lg_max, Dg] (padded) + - glm_mask: BoolTensor [B, Lg_max] + - labels: FloatTensor [B, Lg_max] (padded, zeros where masked) + - ids, tr_sequences, dna_sequences: lists + """ + + def __init__( + self, + tr_shelf_path: str, + dna_shelf_path: str, + tr_key: str = "tr_sequence", + dna_key: str = "dna_sequence", + dtype: torch.dtype = torch.float32, + pad_value: float = 0.0, + ): + self.tr_path = tr_shelf_path + self.dna_path = dna_shelf_path + self.tr_key = tr_key + self.dna_key = dna_key + self.dtype = dtype + self.pad_value = pad_value + + # opened lazily per worker: + self._tr_db = None + self._dna_db = None + + def _ensure_open(self): + if self._tr_db is None: + self._tr_db = shelve.open(self.tr_path, flag="r") # read-only + if self._dna_db is None: + self._dna_db = shelve.open(self.dna_path, flag="r") + + def __call__(self, batch): + """ + batch: list[dict] from Dataset.__getitem__ + """ + self._ensure_open() + + ids = [b.get("ID", None) for b in batch] + tr_seqs = [b[self.tr_key] for b in batch] + dna_seqs = [b[self.dna_key] for b in batch] + scores_list = [b["scores"] for b in batch] + + # 1) Fetch embeddings lazily from shelves + binder_list = [] + glm_list = [] + binder_lens = [] + glm_lens = [] + + for tr, dna, scores in zip(tr_seqs, dna_seqs, scores_list): + # ----- binder/TR ----- + tr_arr = np.asarray(self._tr_db[str(tr)]) + # ensure 2D: [Lb, Db] (if pooled 1D, make length=1) + if tr_arr.ndim == 1: + tr_arr = tr_arr[None, :] + binder_list.append(torch.from_numpy(tr_arr).to(self.dtype)) + binder_lens.append(tr_arr.shape[0]) + + # ----- DNA / GLM ----- + dna_arr = np.asarray(self._dna_db[str(dna)]) + if dna_arr.ndim == 1: + dna_arr = dna_arr[None, :] + glm_list.append(torch.from_numpy(dna_arr).to(self.dtype)) + glm_lens.append(dna_arr.shape[0]) + + # sanity: scores length should match dna length + if len(scores) != dna_arr.shape[0]: + raise ValueError( + f"Length mismatch for DNA seq: shelf length={dna_arr.shape[0]} " + f"but scores length={len(scores)}" + ) + + # 2) Pad sequences to batch max length + binder_emb = pad_sequence( + binder_list, batch_first=True, padding_value=self.pad_value + ) # [B, Lb_max, Db] + glm_emb = pad_sequence( + glm_list, batch_first=True, padding_value=self.pad_value + ) # [B, Lg_max, Dg] + + binder_lens = torch.as_tensor(binder_lens, dtype=torch.int64) + glm_lens = torch.as_tensor(glm_lens, dtype=torch.int64) + + binder_mask = torch.arange(binder_emb.size(1)).unsqueeze( + 0 + ) < binder_lens.unsqueeze( + 1 + ) # [B, Lb_max] + glm_mask = torch.arange(glm_emb.size(1)).unsqueeze(0) < glm_lens.unsqueeze( + 1 + ) # [B, Lg_max] + + # 3) Collate labels for DNA and pad + labels_list = [torch.tensor(s, dtype=torch.float32) for s in scores_list] + labels = pad_sequence( + labels_list, batch_first=True, padding_value=0.0 + ) # [B, Lg_max] + # (Optional) ensure labels are zeroed beyond mask: + labels = labels * glm_mask.to(labels.dtype) + + return { + "binder_emb": binder_emb, # [B, Lb_max, Db] + "binder_mask": binder_mask, # [B, Lb_max] + "glm_emb": glm_emb, # [B, Lg_max, Dg] + "glm_mask": glm_mask, # [B, Lg_max] + "labels": labels, # [B, Lg_max] + "ID": ids, + "tr_sequence": tr_seqs, + "dna_sequence": dna_seqs, + } + + +def 
collate_fn(batch, tr_shelf_path, dna_shelf_path): + Bs = [b.shape[0] for b, _, _ in batch] + Gs = [g.shape[0] for _, g, _ in batch] maxB, maxG = max(Bs), max(Gs) def pad_seq(x, L): if x.shape[0] < L: - pad = torch.zeros((L-x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device) + pad = torch.zeros( + (L - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device + ) return torch.cat([x, pad], dim=0) return x def pad_t(y, L): if y.shape[0] < L: - pad = torch.zeros((L-y.shape[0],), dtype=y.dtype, device=y.device) + pad = torch.zeros((L - y.shape[0],), dtype=y.dtype, device=y.device) return torch.cat([y, pad], dim=0) return y - b_stack = torch.stack([pad_seq(b, maxB) for b,_,_ in batch]) - g_stack = torch.stack([pad_seq(g, maxG) for _,g,_ in batch]) - t_stack = torch.stack([pad_t(t, maxG) for *_,t in batch]) + b_stack = torch.stack([pad_seq(b, maxB) for b, _, _ in batch]) + g_stack = torch.stack([pad_seq(g, maxG) for _, g, _ in batch]) + t_stack = torch.stack([pad_t(t, maxG) for *_, t in batch]) return b_stack, g_stack, t_stack -def build_tf_compressed_cache(binder_paths, target_dim=256): - """ - Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array. - """ - unique_paths = sorted(set(binder_paths)) - logger.info(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True) - # Load all embeddings to determine dimensionality - samples = [] - for p in unique_paths: - arr = np.load(p) - samples.append(arr) - # Determine if reduction needed: assume all have same embedding width - first = samples[0] - orig_dim = first.shape[1] if first.ndim == 2 else 1 - reduction_needed = (orig_dim != target_dim) - tf_cache = {} - - if reduction_needed: - # Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack. - # We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features: - # Simplest: for variable lengths, reduce each embedding separately with a learned linear projection. - # Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length. 
- # To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors: - pooled = [] - for arr in samples: - if arr.ndim == 2: - pooled.append(arr.mean(axis=0)) # (orig_dim,) - else: - pooled.append(arr) # degenerate - pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim) - logger.info(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True) - svd = TruncatedSVD(n_components=target_dim, random_state=42) - reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim) - - # For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T - # svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim) - proj_mat = svd.components_.T # (orig_dim, target_dim) - for i, p in enumerate(unique_paths): - arr = samples[i] # shape (L, orig_dim) - if arr.ndim == 1: - arr2 = arr @ proj_mat # (target_dim,) - else: - # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim) - arr2 = arr @ proj_mat - tf_cache[p] = arr2 # reduced per-token representation - logger.info("[i] Completed compression of TF embeddings.", flush=True) - else: - # already correct dim: just cache originals - logger.info(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True) - for i, p in enumerate(unique_paths): - arr = samples[i] - tf_cache[p] = arr - return tf_cache - -def evaluate(model, dl, device): - model.eval() - all_labels = [] - all_preds = [] - with torch.no_grad(): - for b, g, y in dl: - b = b.to(device) - g = g.to(device) - y = y.to(device) - pred = model(b, g) - all_labels.append(y.cpu()) - all_preds.append(pred.cpu()) - if not all_labels: - return 0.0, 0.0 - y_true = torch.cat(all_labels).numpy() - y_score = torch.cat(all_preds).numpy() - try: - auc = roc_auc_score(y_true, y_score) - except Exception: - auc = 0.0 - try: - ap = average_precision_score(y_true, y_score) - except Exception: - ap = 0.0 - return auc, ap - -# ---- main ------------------------------------------------------------ + +# ------------------------ Helpers for main method debugging only ------------------------------------------# +def _peek_batches(dl, n_batches: int = 2, tag: str = "train"): + logger.info(f"\n=== Peek {n_batches} batch(es) from {tag} loader ===") + for i, batch in enumerate(dl): + be = batch["binder_emb"] + bm = batch["binder_mask"] + ge = batch["glm_emb"] + gm = batch["glm_mask"] + y = batch["labels"] + ids = batch.get("ID", [""] * be.size(0)) + + logger.info(f"\n[{tag}] batch {i+1}") + logger.info(f" binder_emb: {tuple(be.shape)} dtype={be.dtype}") + logger.info(f" binder_mask true count: {bm.sum().item()} / {bm.numel()}") + logger.info(f" glm_emb: {tuple(ge.shape)} dtype={ge.dtype}") + logger.info(f" glm_mask true count: {gm.sum().item()} / {gm.numel()}") + logger.info( + f" labels: {tuple(y.shape)} min={y.min().item():.4f} max={y.max().item():.4f}" + ) + logger.info(f" IDs (first 5): {ids[:5]}") + if i + 1 >= n_batches: + break + + +def _warn_on_paths(args): + import os + + for p, label in [ + (args.train_file, "train_file"), + (args.val_file, "val_file"), + (args.test_file, "test_file"), + (args.tr_shelf_path, "tr_shelf_path"), + (args.dna_shelf_path, "dna_shelf_path"), + ]: + if p and not os.path.exists(p): + logger.info(f"{label} does not exist: {p}") + if str(args.tr_shelf_path).endswith(".pkl"): + logger.info( + "Warning: tr_shelf_path ends with .pkl but ShelfCollator expects a shelve DB " + "(e.g., `.shelf`). 
Pass the correct path via --tr_shelf_path." + ) + + def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--pair_list", type=str, required=True, - help="TSV: binder_path glm_path label") - parser.add_argument("--out_dir", type=str, required=True) - parser.add_argument("--epochs", type=int, default=10) - parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--lr", type=float, default=1e-4) - parser.add_argument("--device", type=str, default="cuda") - parser.add_argument("--seed", type=int, default=42) + parser = argparse.ArgumentParser( + description="Peek pre-batched, shelf-backed dataloaders" + ) + parser.add_argument( + "--train_file", + type=str, + default="../data_files/processed/splits/by_dna/babytrain.csv", + ) + parser.add_argument( + "--val_file", + type=str, + default="../data_files/processed/splits/by_dna/babyval.csv", + ) + parser.add_argument( + "--test_file", + type=str, + default="../data_files/processed/splits/by_dna/babytest.csv", + ) + parser.add_argument( + "--tr_shelf_path", + type=str, + default="../data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf", + ) + parser.add_argument( + "--dna_shelf_path", + type=str, + default="../data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf", + ) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument( + "--debug_run", action="store_true", help="limit dataset to a few rows" + ) + parser.add_argument( + "--n_batches", type=int, default=2, help="how many batches to print per split" + ) + parser.add_argument("--shuffle_train_batch_order", action="store_true") args = parser.parse_args() - # reproducibility - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - - logger.info("DEBUG: starting training script with in-line TF compression", flush=True) - logger.info(f"[i] pair_list: {args.pair_list}", flush=True) - logger.info(f"[i] output dir: {args.out_dir}", flush=True) - device = torch.device(args.device if torch.cuda.is_available() else "cpu") - binder_paths, glm_paths, labels = parse_pair_list(args.pair_list) - - if len(labels) == 0: - logger.info("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr) - sys.exit(1) - - label_counts = Counter(labels) - logger.info(f"[i] Total examples parsed: {len(labels)}. 
Label distribution: {label_counts}", flush=True)
-
-    # build compressed TF cache (reduces to 256 if needed)
-    tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
-
-    # simple split: 80/10/10
-    n = len(labels)
-    idxs = np.arange(n)
-    np.random.shuffle(idxs)
-    train_i = idxs[: int(0.8 * n)]
-    val_i = idxs[int(0.8 * n): int(0.9 * n)]
-    test_i = idxs[int(0.9 * n):]
-
-    def subset(idxs):
-        return [binder_paths[i] for i in idxs], [glm_paths[i] for i in idxs], [labels[i] for i in idxs]
-
-    train_ds = PairDataset(*subset(train_i), tf_compressed_cache=tf_compressed_cache)
-    val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
-    test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)
-
-    logger.info(f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}", flush=True)
-    if len(train_ds) == 0 or len(val_ds) == 0:
-        logger.info("[ERROR] Train or validation split is empty; cannot proceed.", file=sys.stderr)
-        sys.exit(1)
-
-    train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
-    val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
-    test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
-
-    model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True)
-    model = model.to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
-    loss_fn = nn.BCELoss()
-
-    best_val = -float("inf")
-    os_out = Path(args.out_dir)
-    os_out.mkdir(exist_ok=True, parents=True)
-
-    for epoch in range(1, args.epochs + 1):
-        logger.info(f"[Epoch {epoch}] starting...", flush=True)
-        model.train()
-        running_loss = 0.0
-        for b, g, y in train_dl:
-            b = b.to(device)
-            g = g.to(device)
-            y = y.to(device)
-            pred = model(b, g)
-            loss = loss_fn(pred, y)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-            running_loss += loss.item() * b.size(0)
-        train_loss = running_loss / len(train_ds)
-        val_auc, val_ap = evaluate(model, val_dl, device)
-        logger.info(f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}", flush=True)
-
-        if val_auc > best_val:
-            best_val = val_auc
-            torch.save(model.state_dict(), os_out / "best_model.pt")
-            logger.info(f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}", flush=True)
-
-    torch.save(model.state_dict(), os_out / "last_model.pt")
-    test_auc, test_ap = evaluate(model, test_dl, device)
-    logger.info(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
-    logger.info(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)
+    _warn_on_paths(args)
+
+    dm = PairDataModule(
+        train_file=args.train_file,
+        val_file=args.val_file,
+        test_file=args.test_file,
+        tr_shelf_path=args.tr_shelf_path,
+        dna_shelf_path=args.dna_shelf_path,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        debug_run=args.debug_run,
+        shuffle_train_batch_order=args.shuffle_train_batch_order,
+        pin_memory=False,
+    )
+
+    # ---- Train ----
+    dm.setup(stage="fit")
+    train_dl = dm.train_dataloader()
+    _peek_batches(train_dl, n_batches=args.n_batches, tag="train")
+
+    # ---- Val ----
+    dm.setup(stage="validate")
+    val_dl = dm.val_dataloader()
+    _peek_batches(val_dl, n_batches=1, tag="val")  # usually enough to sanity-check
+
+    # ---- Test ----
+    dm.setup(stage="test")
+    test_dl = dm.test_dataloader()
+    _peek_batches(test_dl, n_batches=1, tag="test")
+
+
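Note on the hunk above: _peek_batches only prints shapes and mask counts; the batch dict it inspects follows the ShelfCollator contract documented earlier (binder_emb, binder_mask, glm_emb, glm_mask, labels, plus IDs and raw sequences). A minimal sketch of how a downstream step could consume that contract with masking; the model call and its logits output are assumptions, not code from this diff:

import torch.nn.functional as F

def masked_bce_step(model, batch):
    # "model" is hypothetical here and assumed to return per-base logits of shape [B, Lg_max]
    logits = model(
        batch["binder_emb"], batch["binder_mask"],
        batch["glm_emb"], batch["glm_mask"],
    )
    mask = batch["glm_mask"].float()  # 1.0 for real bases, 0.0 for padding
    per_base = F.binary_cross_entropy_with_logits(
        logits, batch["labels"], reduction="none"
    )
    # average the loss only over unpadded positions
    return (per_base * mask).sum() / mask.sum()

Averaging only over positions where glm_mask is true keeps the zero-padded tails of shorter DNA sequences from diluting the per-base loss.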
logger.info("\nAll good") + if __name__ == "__main__": + # (Optional) set a deterministic seed for batch order shuffling + torch.manual_seed(0) + random.seed(0) + np.random.seed(0) + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(name)s:%(lineno)d | %(message)s", + datefmt="%H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], # stdout, not stderr + force=True, # override any prior config from imported libs + ) + main() diff --git a/dpacman/data_tasks/clean/remap.py b/dpacman/data_tasks/clean/remap.py index d8e576e29f218c627f207f80ccbcb1e6b7ab3ee6..fdc26d386b5813668b0d3c0ae1459da02c1be866 100644 --- a/dpacman/data_tasks/clean/remap.py +++ b/dpacman/data_tasks/clean/remap.py @@ -1,12 +1,13 @@ import pandas as pd from omegaconf import DictConfig from pathlib import Path -import rootutils -import logging import os +import rootutils +from dpacman.utils import pylogger + root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -logger = logging.getLogger(__name__) +logger = pylogger.RankedLogger(__name__, rank_zero_only=True) def clean_nr(nr_raw_path: Path | str): diff --git a/dpacman/data_tasks/cluster/remap.py b/dpacman/data_tasks/cluster/remap.py index 49e0bee54998bdbca8aef0f08251f7e799db274b..a400998b75dafbf62223a94b5523c33f24081df7 100644 --- a/dpacman/data_tasks/cluster/remap.py +++ b/dpacman/data_tasks/cluster/remap.py @@ -1,6 +1,7 @@ """ -Holds Python methods for clustering Remap DNA sequences. +Holds Python methods for clustering Remap DNA sequences. """ + import argparse import numpy as np import pandas as pd @@ -9,19 +10,35 @@ import random import sys import subprocess from collections import defaultdict -import rootutils -import logging import os import json from omegaconf import DictConfig from hydra.core.hydra_config import HydraConfig -from dpacman.utils.clustering import make_fasta, process_fasta, analyze_clustering_result, run_mmseqs_clustering, cluster_summary +from dpacman.utils.clustering import ( + make_fasta, + process_fasta, + analyze_clustering_result, + run_mmseqs_clustering, + cluster_summary, +) + +import rootutils +from dpacman.utils import pylogger root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -logger = logging.getLogger(__name__) +logger = pylogger.RankedLogger(__name__, rank_zero_only=True) -def cluster_molecules(fasta_dict, fasta_path, mmseqs_params: DictConfig, output_dir="", path_to_mmseqs="../softwares/mmseqs", moltype="dna", use_gpu=True): + +def cluster_molecules( + fasta_dict, + fasta_path, + mmseqs_params: DictConfig, + output_dir="", + path_to_mmseqs="../softwares/mmseqs", + moltype="dna", + use_gpu=True, +): """ Args: - fasta_dict: dictionary object where the keys are sequence IDs, and the values are sequences @@ -29,43 +46,45 @@ def cluster_molecules(fasta_dict, fasta_path, mmseqs_params: DictConfig, output_ - mmseqs_params: DictConfig of mmseqs hparams - type: molecule type, "dna" or "protein" """ - + # make the fasta logger.info(f"Making fasta at: {fasta_path}") fasta_path = str(make_fasta(fasta_dict, fasta_path)) - + # prepare directories output_dir = str(Path(root) / output_dir) path_to_mmseqs = str(Path(root) / path_to_mmseqs) - + # run mmseqs - dbtype=1 - if moltype=="dna": dbtype=2 - run_mmseqs_clustering(fasta_path, - output_dir, - min_seq_id=mmseqs_params.min_seq_id, - c=mmseqs_params.c, - cov_mode=mmseqs_params.cov_mode, - cluster_mode=mmseqs_params.cluster_mode, - dbtype=dbtype, - path_to_mmseqs=path_to_mmseqs) - - tsv_path = [x for x in 
os.listdir(output_dir) if x.endswith(".tsv")][0] - clusters = analyze_clustering_result( - fasta_path, Path(output_dir) / tsv_path + dbtype = 1 + if moltype == "dna": + dbtype = 2 + run_mmseqs_clustering( + fasta_path, + output_dir, + min_seq_id=mmseqs_params.min_seq_id, + c=mmseqs_params.c, + cov_mode=mmseqs_params.cov_mode, + cluster_mode=mmseqs_params.cluster_mode, + dbtype=dbtype, + path_to_mmseqs=path_to_mmseqs, ) + + tsv_path = [x for x in os.listdir(output_dir) if x.endswith(".tsv")][0] + clusters = analyze_clustering_result(fasta_path, Path(output_dir) / tsv_path) logger.info(f"Made clusters DataFrame:\n{clusters.head()}") cluster_summary(clusters) + def read_input_data(input_path): """ - Read the data from the input path. + Read the data from the input path. It may be a csv or parquet """ input_path = Path(root) / input_path df = None if str(input_path).endswith(".parquet"): - df = pd.read_parquet(input_path, engine='pyarrow') + df = pd.read_parquet(input_path, engine="pyarrow") elif str(input_path).endswith(".csv"): df = pd.read_csv(input_path) elif str(input_path).endswith(".tsv") or str(input_path).endswith(".txt"): @@ -73,7 +92,8 @@ def read_input_data(input_path): else: raise Exception(f"Cannot read input data from {input_path}: invalid file type") return df - + + def main(cfg: DictConfig): """ Run clustering on Remap protein AND DNA sequences. @@ -87,8 +107,10 @@ def main(cfg: DictConfig): dna_full_cfg = cfg.data_task.dna_full dna_peaks_cfg = cfg.data_task.dna_peaks protein_cfg = cfg.data_task.protein - logger.info(f"Clustering DNA full: {cfg.data_task.cluster_dna_full}. Clustering DNA peaks: {cfg.data_task.cluster_dna_peaks}. Clustering protein: {cfg.data_task.cluster_protein}.") - + logger.info( + f"Clustering DNA full: {cfg.data_task.cluster_dna_full}. Clustering DNA peaks: {cfg.data_task.cluster_dna_peaks}. Clustering protein: {cfg.data_task.cluster_protein}." + ) + # Make fastas dna_full_fasta_path = Path(root) / dna_full_cfg.fasta_path dna_peaks_fasta_path = Path(root) / dna_peaks_cfg.fasta_path @@ -96,49 +118,85 @@ def main(cfg: DictConfig): os.makedirs(dna_full_fasta_path.parent, exist_ok=True) os.makedirs(dna_peaks_fasta_path.parent, exist_ok=True) os.makedirs(protein_fasta_path.parent, exist_ok=True) - + # Make dictioary needed for input to the fasta methods with open(Path(root) / dna_full_cfg.input_map_path, "r") as f: dna_full_fasta_dict = json.load(f) - + with open(Path(root) / dna_peaks_cfg.input_map_path, "r") as f: dna_peaks_fasta_dict = json.load(f) - + with open(Path(root) / protein_cfg.input_map_path, "r") as f: protein_fasta_dict = json.load(f) - - logger.info(f"Loaded DNA seq dict from: {dna_full_cfg.input_map_path}. Size: {len(dna_full_fasta_dict)}") - logger.info(f"Loaded DNA peaks dict from: {dna_peaks_cfg.input_map_path}. Size: {len(dna_peaks_fasta_dict)}") - logger.info(f"Loaded TR (protein) seq dict from: {protein_cfg.input_map_path}. Size: {len(protein_fasta_dict)}") - + + logger.info( + f"Loaded DNA seq dict from: {dna_full_cfg.input_map_path}. Size: {len(dna_full_fasta_dict)}" + ) + logger.info( + f"Loaded DNA peaks dict from: {dna_peaks_cfg.input_map_path}. Size: {len(dna_peaks_fasta_dict)}" + ) + logger.info( + f"Loaded TR (protein) seq dict from: {protein_cfg.input_map_path}. 
Size: {len(protein_fasta_dict)}" + ) + # Build hash-sets once (drop NaNs to avoid weird matches) - dna_ids = set(df["dna_seqid"].dropna()) - peak_ids = set(df["peak_seqid"].dropna()) - tr_ids = set(df["tr_seqid"].dropna()) + dna_ids = set(df["dna_seqid"].dropna()) + peak_ids = set(df["peak_seqid"].dropna()) + tr_ids = set(df["tr_seqid"].dropna()) # Iterate only the intersection (fast when allowed << dict size) - dna_full_fasta_dict = {k: dna_full_fasta_dict[k] for k in (dna_full_fasta_dict.keys() & dna_ids)} - dna_peaks_fasta_dict = {k: dna_peaks_fasta_dict[k] for k in (dna_peaks_fasta_dict.keys() & peak_ids)} - protein_fasta_dict = {k: protein_fasta_dict[k] for k in (protein_fasta_dict.keys() & tr_ids)} - - logger.info(f"Filtered dictionaries to only sequences in the filtered training data.") - logger.info(f"Total DNA sequences: {len(dna_full_fasta_dict)}. Total peak sequences: {len(dna_peaks_fasta_dict)}. Total protein sequences: {len(protein_fasta_dict)}") - + dna_full_fasta_dict = { + k: dna_full_fasta_dict[k] for k in (dna_full_fasta_dict.keys() & dna_ids) + } + dna_peaks_fasta_dict = { + k: dna_peaks_fasta_dict[k] for k in (dna_peaks_fasta_dict.keys() & peak_ids) + } + protein_fasta_dict = { + k: protein_fasta_dict[k] for k in (protein_fasta_dict.keys() & tr_ids) + } + + logger.info( + f"Filtered dictionaries to only sequences in the filtered training data." + ) + logger.info( + f"Total DNA sequences: {len(dna_full_fasta_dict)}. Total peak sequences: {len(dna_peaks_fasta_dict)}. Total protein sequences: {len(protein_fasta_dict)}" + ) + if cfg.data_task.cluster_dna_full: logger.info(f"Clustering DNA full sequences, with context") - cluster_molecules(dna_full_fasta_dict, dna_full_fasta_path, - mmseqs_params=dna_full_cfg.mmseqs, output_dir=dna_full_cfg.output_dir, path_to_mmseqs=cfg.data_task.path_to_mmseqs, moltype="dna") - + cluster_molecules( + dna_full_fasta_dict, + dna_full_fasta_path, + mmseqs_params=dna_full_cfg.mmseqs, + output_dir=dna_full_cfg.output_dir, + path_to_mmseqs=cfg.data_task.path_to_mmseqs, + moltype="dna", + ) + if cfg.data_task.cluster_dna_peaks: logger.info(f"Clustering DNA peak sequences") - cluster_molecules(dna_peaks_fasta_dict, dna_peaks_fasta_path, - mmseqs_params=dna_peaks_cfg.mmseqs, output_dir=dna_peaks_cfg.output_dir, path_to_mmseqs=cfg.data_task.path_to_mmseqs, moltype="dna") - + cluster_molecules( + dna_peaks_fasta_dict, + dna_peaks_fasta_path, + mmseqs_params=dna_peaks_cfg.mmseqs, + output_dir=dna_peaks_cfg.output_dir, + path_to_mmseqs=cfg.data_task.path_to_mmseqs, + moltype="dna", + ) + if cfg.data_task.cluster_protein: logger.info("Clustering protein sequences.") - cluster_molecules(protein_fasta_dict, protein_fasta_path, mmseqs_params=protein_cfg.mmseqs, output_dir=protein_cfg.output_dir, path_to_mmseqs=cfg.data_task.path_to_mmseqs, moltype="protein") + cluster_molecules( + protein_fasta_dict, + protein_fasta_path, + mmseqs_params=protein_cfg.mmseqs, + output_dir=protein_cfg.output_dir, + path_to_mmseqs=cfg.data_task.path_to_mmseqs, + moltype="protein", + ) logger.info("Clustering pipeline complete") - + + if __name__ == "__main__": main() diff --git a/dpacman/data_tasks/embeddings/__init__.py b/dpacman/data_tasks/embeddings/__init__.py index 5998f502d7d9d3390569c90636df6cf89cf7063c..fcdb263f9f9652aa11a99fb7ada3637df0cc8732 100644 --- a/dpacman/data_tasks/embeddings/__init__.py +++ b/dpacman/data_tasks/embeddings/__init__.py @@ -1,24 +1,33 @@ from .embedders import ( - CaduceusEmbedder, - DNABertEmbedder, - NucleotideTransformerEmbedder, - 
GPNEmbedder, - SegmentNTEmbedder, + CaduceusEmbedder, + DNABertEmbedder, + NucleotideTransformerEmbedder, + GPNEmbedder, + SegmentNTEmbedder, ESMEmbedder, ESMDBPEmbedder, - ProGenEmbedder + ProGenEmbedder, ) + def get_embedder(name, device, for_dna=True): name = name.lower() if for_dna: - if name=="caduceus": return CaduceusEmbedder(device) - if name=="dnabert": return DNABertEmbedder(device) - if name=="nucleotide": return NucleotideTransformerEmbedder(device) - if name=="gpn": return GPNEmbedder(device) - if name=="segmentnt": return SegmentNTEmbedder(device) + if name == "caduceus": + return CaduceusEmbedder(device) + if name == "dnabert": + return DNABertEmbedder(device) + if name == "nucleotide": + return NucleotideTransformerEmbedder(device) + if name == "gpn": + return GPNEmbedder(device) + if name == "segmentnt": + return SegmentNTEmbedder(device) else: - if name in ("esm",): return ESMEmbedder(device) - if name in ("esm-dbp","esm_dbp"): return ESMDBPEmbedder(device) - if name=="progen": return ProGenEmbedder(device) - raise ValueError(f"Unknown model {name} (for_dna={for_dna})") \ No newline at end of file + if name in ("esm",): + return ESMEmbedder(device) + if name in ("esm-dbp", "esm_dbp"): + return ESMDBPEmbedder(device) + if name == "progen": + return ProGenEmbedder(device) + raise ValueError(f"Unknown model {name} (for_dna={for_dna})") diff --git a/dpacman/data_tasks/embeddings/dna.py b/dpacman/data_tasks/embeddings/dna.py index daf6005270a358583dd8eac9bc8e9e80a150b53f..acc485afb0a5115ed0097fd1b363d9a2b01e6b21 100644 --- a/dpacman/data_tasks/embeddings/dna.py +++ b/dpacman/data_tasks/embeddings/dna.py @@ -1,52 +1,80 @@ +""" +Embed DNA sequences from ReMap peaks. +""" + from .utils import pad_token_embeddings, embed_and_save from dpacman.data_tasks.embeddings import get_embedder -import logging -import rootutils -import os +import os import torch import json import pandas as pd from pathlib import Path from omegaconf import DictConfig +import rootutils +from dpacman.utils import pylogger root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -logger = logging.getLogger(__name__) +logger = pylogger.RankedLogger(__name__, rank_zero_only=True) + def main(cfg: DictConfig): - logger.info(f"Making embeddings using {cfg.data_task.chrom_model} for dna sequences at {cfg.data_task.input_file}") + logger.info( + f"Making embeddings using {cfg.data_task.chrom_model} for dna sequences at {cfg.data_task.input_file}" + ) # make out dir if necessary out_dir = Path(root) / cfg.data_task.out_dir os.makedirs(out_dir, exist_ok=True) - + # set device device = "cpu" - if cfg.data_task.device=="gpu": + if cfg.data_task.device == "gpu": if torch.cuda.is_available(): device = "cuda" logger.info(f"Using device: {device}") - # read the input file + # read the input file input_file = Path(root) / cfg.data_task.input_file if str(input_file).endswith(".json"): - # load the json and isolate the sequences and ids + # load the json and isolate the sequences and ids with open(input_file, "r") as f: d = json.load(f) df = pd.DataFrame.from_dict(d, orient="index").reset_index() - df.columns = ["seq_id","sequence"] - + df.columns = ["seq_id", "sequence"] + + if cfg.data_task.debug: + logger.info(f"DEBUG MODE. 
Only embedding 5 sequences") + df = df.sample(n=5, random_state=42).reset_index(drop=True) + + # crucial: sort by SIZE so that we are padding things in a way that will support pre-batching + df["sequence_length"] = df["sequence"].str.len() + df = df.sort_values(by="sequence_length", ascending=True).reset_index(drop=True) + # turn into list of sequences and IDs peak_seqs = df["sequence"].tolist() peak_ids = df["seq_id"].tolist() - logger.info(f"Embedding {len(peak_seqs)} binding peak sequences from processed remap data") - + logger.info( + f"Embedding {len(peak_seqs)} binding peak sequences from processed remap data" + ) + # Get the DNA embedder dna_embedder = get_embedder(cfg.data_task.chrom_model, device, for_dna=True) - out_peaks = out_dir/ f"peaks_{cfg.data_task.chrom_model}.npy" - embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks) + logger.info(f"Device of embedding model: {dna_embedder.device}") + out_peaks = str(out_dir / f"peaks_{cfg.data_task.chrom_model}.pkl") + if cfg.data_task.debug: + out_peaks = out_peaks.replace(".pkl", "_debug.pkl") + + embed_and_save( + peak_seqs, + peak_ids, + dna_embedder, + out_peaks, + batch_size=cfg.data_task.batch_size, + ) + + logger.info(f"Finished embedding DNA sequences. Saved to: {out_peaks}") + - logger.info("Finished embedding DNA sequences.") - -if __name__=="__main__": - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/dpacman/data_tasks/embeddings/embedders.py b/dpacman/data_tasks/embeddings/embedders.py index c8c5cff26aacd39e951b68cfd97e388e3e136d57..3cec1ba4d06f902dc5b0fdb7710ff8da664c01a8 100644 --- a/dpacman/data_tasks/embeddings/embedders.py +++ b/dpacman/data_tasks/embeddings/embedders.py @@ -14,23 +14,24 @@ Usage example (DNA + protein in one go): --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \ --device cuda """ -import os -import re -import argparse -import json + import numpy as np from pathlib import Path import torch from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline -from Bio import SeqIO import time -import pandas as pd import esm from tqdm.auto import tqdm -import logging, math +from sklearn.preprocessing import OneHotEncoder +import math +import rootutils +from dpacman.utils import pylogger +root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +logger = pylogger.RankedLogger(__name__, rank_zero_only=True) # ---- model wrappers ---- + class CaduceusEmbedder: def __init__(self, device, chunk_size=131_072, overlap=0): """ @@ -42,14 +43,16 @@ class CaduceusEmbedder: self.tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=True ) - self.model = AutoModel.from_pretrained( - model_name, trust_remote_code=True - ).to(device).eval() - self.device = device + self.model = ( + AutoModel.from_pretrained(model_name, trust_remote_code=True) + .to(device) + .eval() + ) + self.device = device self.chunk_size = chunk_size - self.step = chunk_size - overlap + self.step = chunk_size - overlap - def embed(self, seqs): + def embed(self, seqs, batch_size=1): """ seqs: List[str] of DNA sequences (each <= chunk_size for this test) returns: np.ndarray of shape (N, L, D), raw per‐token embeddings @@ -70,20 +73,21 @@ class CaduceusEmbedder: # return np.stack(outputs, axis=0) # (N, L, D) outputs = [] - for seq in tqdm(seqs, total=len(seqs), desc="DNA: Caduceus", dynamic_ncols=True): + for seq in tqdm( + seqs, total=len(seqs), desc="DNA: Caduceus", dynamic_ncols=True + ): toks = self.tokenizer( seq, 
return_tensors="pt", padding=False, truncation=True, - max_length=self.chunk_size + max_length=self.chunk_size, ).to(self.device) with torch.no_grad(): out = self.model(**toks).last_hidden_state # (1, L, D) - outputs.append(out.cpu().numpy()[0]) # (L, D) + outputs.append(out.cpu().numpy()[0]) # (L, D) return outputs # list of variable-length (L_i, D) arrays - def benchmark(self, lengths=None): """ Time embedding on single-sequence of various lengths. @@ -104,41 +108,85 @@ class CaduceusEmbedder: t1 = time.perf_counter() print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms") + class SegmentNTEmbedder: def __init__(self, device): - self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True) - self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval() + self.tokenizer = AutoTokenizer.from_pretrained( + "InstaDeepAI/segment_nt", trust_remote_code=True + ) + self.model = ( + AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True) + .to(device) + .eval() + ) self.device = device def _adjust_length(self, input_ids): + """ + Pads the length so it's divisible by 4; this is needed to get through the BPNet + """ bs, L = input_ids.shape excl = L - 1 remainder = (excl) % 4 if remainder != 0: pad_needed = 4 - remainder - pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device) + pad_tensor = torch.full( + (bs, pad_needed), + self.tokenizer.pad_token_id, + dtype=input_ids.dtype, + device=input_ids.device, + ) input_ids = torch.cat([input_ids, pad_tensor], dim=1) return input_ids - def embed(self, seqs, batch_size=16): + def embed(self, seqs, batch_size=1, log_every_pct=5, pooling=False): """ seqs: List[str] - Returns: np.ndarray of shape (N, D) + Returns: Dict[str, np.ndarray] + - pooling=True -> {seq: (D,)} + - pooling=False -> {seq: (L-1, D)} (excludes CLS, retains padding/truncation) """ - all_embeddings = [] - for i in range(0, len(seqs), batch_size): + n = len(seqs) + if n == 0: + return {} + + # Progress checkpoints: 5%, 10%, ..., 100% + steps = list(range(log_every_pct, 101, log_every_pct)) + checkpoints = [max(1, math.ceil(n * p / 100)) for p in steps] + ck_idx = 0 + processed = 0 + + # (Optional) quick info; uses logger if provided, else print + try: + max_len = max(len(s) for s in seqs) + msg = ( + f"Max length (will be padded/truncated to tokenizer setting): {max_len}" + ) + (logger.info if logger is not None else print)(msg) + except Exception: + pass + + out = {} # seq -> embedding + + for i in range(0, n, batch_size): batch_seqs = seqs[i : i + batch_size] + encoded = self.tokenizer.batch_encode_plus( batch_seqs, return_tensors="pt", padding=True, truncation=True, + max_length=1998, # keep your existing cap ) - input_ids = encoded["input_ids"].to(self.device) # (B, L) - attention_mask = input_ids != self.tokenizer.pad_token_id + orig_len = encoded["input_ids"].shape[1] + + input_ids = encoded["input_ids"].to(self.device) # (B, L) + logger.info(f"input_ids.shape: {input_ids.shape}") + # (Re)compute mask after any length adjustment input_ids = self._adjust_length(input_ids) - attention_mask = (input_ids != self.tokenizer.pad_token_id) + logger.info(f"after adjusting length: input_ids.shape: {input_ids.shape}") + attention_mask = input_ids != self.tokenizer.pad_token_id with torch.no_grad(): outs = self.model( @@ -147,36 +195,97 @@ class SegmentNTEmbedder: output_hidden_states=True, return_dict=True, ) - if hasattr(outs, 
"hidden_states") and outs.hidden_states is not None: - last_hidden = outs.hidden_states[-1] # (B, L, D) + + last_hidden = ( + outs.hidden_states[-1] + if getattr(outs, "hidden_states", None) is not None + else outs.last_hidden_state + ) # (B, L, D) + logger.info(f"last_hidden.shape: {last_hidden.shape}") + + # Exclude CLS token (assumed first position) + last_hidden = last_hidden[ + :, 1:orig_len, : + ] # keep only CLS-dropped original positions. Exclude the pads + logger.info( + f"after cutting first position: last_hidden.shape: {last_hidden.shape}" + ) + + if pooling: + # Match your original behavior: simple mean over tokens (no mask) + pooled = last_hidden.mean(dim=1) # (B, D) + pooled_np = pooled.detach().cpu().numpy() + for j, s in enumerate(batch_seqs): + out[s] = pooled_np[j] else: - last_hidden = outs.last_hidden_state # fallback + # Keep per-token embeddings (still padded/truncated) + emb_np = last_hidden.detach().cpu().numpy() # (B, L-1, D) + for j, s in enumerate(batch_seqs): + out[s] = emb_np[j] + + processed += len(batch_seqs) - # Exclude CLS token if present (assume first token) and pool - pooled = last_hidden[:, 1:, :].mean(dim=1) # (B, D) - all_embeddings.append(pooled.cpu().numpy()) + # Log only the highest checkpoint crossed this batch + while ck_idx < len(checkpoints) and processed >= checkpoints[ck_idx]: + pct = steps[ck_idx] + msg = f"[embed] {processed}/{n} ({pct}%)" + try: + (logger.info if logger is not None else print)(msg) + except Exception: + print(msg, flush=True) + ck_idx += 1 - # release fragmentation - torch.cuda.empty_cache() + # reduce CUDA memory fragmentation (safe no-op on CPU) + try: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except Exception: + pass - return np.vstack(all_embeddings) # (N, D) + return out class DNABertEmbedder: def __init__(self, device): - self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True) - self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device) - self.device = device + self.tokenizer = AutoTokenizer.from_pretrained( + "zhihan1996/DNA_bert_6", trust_remote_code=True + ) + self.model = AutoModel.from_pretrained( + "zhihan1996/DNA_bert_6", trust_remote_code=True + ).to(device) + self.device = device - def embed(self, seqs): + def embed(self, seqs, batch_size=1): embs = [] for s in seqs: - tokens = self.tokenizer(s, return_tensors="pt", padding=True)["input_ids"].to(self.device) + tokens = self.tokenizer(s, return_tensors="pt", padding=True)[ + "input_ids" + ].to(self.device) with torch.no_grad(): out = self.model(tokens).last_hidden_state.mean(1) embs.append(out.cpu().numpy()) return np.vstack(embs) + +class OneHotEmbedder: + """ + Simple one-hot encoder as a baseline + """ + + def __init__(self, device=None): + self.nucleotides = [list("ACTGN")] + self.model = OneHotEncoder(categories=self.nucleotides, dtype=int) + + def embed(self, seqs, batch_size=1): + out = {} + for s in seqs: + # tokenize + tokens = np.array(list(s)).reshape(-1, 1) + embedding = self.model.fit_transform(tokens).toarray() + out[s] = embedding + return out + + class NucleotideTransformerEmbedder: def __init__(self, device): # HF “feature-extraction” returns a list of (L, D) arrays for each input @@ -184,18 +293,21 @@ class NucleotideTransformerEmbedder: self.pipe = pipeline( "feature-extraction", model="InstaDeepAI/nucleotide-transformer-500m-1000g", - device= -1 if device=="cpu" else 0 # HF uses -1 for CPU, 0 for GPU #:contentReference[oaicite:0]{index=0} + 
device=( + -1 if device == "cpu" else 0 + ), # HF uses -1 for CPU, 0 for GPU #:contentReference[oaicite:0]{index=0} ) - def embed(self, seqs): + def embed(self, seqs, batch_size=1): """ seqs: List[str] of raw DNA sequences returns: (N, D) array, one D-dim vector per sequence """ all_embeddings = self.pipe(seqs, truncation=True, padding=True) # all_embeddings is a List of shape (L, D) arrays - pooled = [ np.mean(x, axis=0) for x in all_embeddings ] - return np.vstack(pooled) + pooled = [np.mean(x, axis=0) for x in all_embeddings] + return np.vstack(pooled) + class ESMEmbedder: def __init__(self, device, model_name="esm2_t33_650M_UR50D"): @@ -211,7 +323,9 @@ class ESMEmbedder: self.batch_converter = self.alphabet.get_batch_converter() self.model.to(device).eval() # determine max length: esm2 models vary; use default 1024 for esm1b - self.max_len = 4096 if self.is_esm2 else 1024 # adjust if your esm2 variant has explicit limit + self.max_len = ( + 4096 if self.is_esm2 else 1024 + ) # adjust if your esm2 variant has explicit limit # for chunking: reserve 2 tokens if model uses BOS/EOS self.chunk_size = self.max_len - 2 self.overlap = self.chunk_size // 4 # 25% overlap to smooth boundaries @@ -222,6 +336,7 @@ class ESMEmbedder: """ if len(seq) <= self.chunk_size: return [seq] + logger.info(f"Calling chunk sequence") step = self.chunk_size - self.overlap chunks = [] for i in range(0, len(seq), step): @@ -231,12 +346,12 @@ class ESMEmbedder: chunks.append(chunk) return chunks - def embed(self, seqs): + def embed(self, seqs, batch_size=1, avg=False): """ seqs: List[str] of protein sequences. - Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings. + Returns: np.ndarray of: shape (N, D) pooled per-sequence embeddings if avg true; shape (N, L, D) otherwise """ - all_embeddings = [] + all_embeddings = {} for i, seq in enumerate(seqs): chunks = self._chunk_sequence(seq) chunk_vecs = [] @@ -250,25 +365,30 @@ class ESMEmbedder: reps = results["representations"][33] # (1, L, D) # remove BOS/EOS if present: take 1:-1 if length permits if reps.size(1) > 2: - rep = reps[:, 1:-1].mean(1) # (1, D) - else: - rep = reps.mean(1) # fallback + rep = reps[:, 1:-1] # (L, D) + if avg: + rep = reps.mean(1) # (1, D) chunk_vecs.append(rep.squeeze(0)) # (D,) + # if we did NOT have to chunk (sequence <= max lenth) if len(chunk_vecs) == 1: seq_vec = chunk_vecs[0] + # if we DID hav eto chunk (sequence > max length) else: # average chunk vectors stacked = torch.stack(chunk_vecs, dim=0) # (num_chunks, D) seq_vec = stacked.mean(0) - all_embeddings.append(seq_vec.cpu().numpy()) - return np.vstack(all_embeddings) # (N, D) + all_embeddings[seq] = seq_vec.cpu().numpy() + return all_embeddings # (N, D) + class ESMDBPEmbedder: def __init__(self, device): base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S() model_path = ( Path(__file__).resolve().parent.parent - / "pretrained" / "ESM-DBP" / "ESM-DBP.model" + / "pretrained" + / "ESM-DBP" + / "ESM-DBP.model" ) checkpoint = torch.load(model_path, map_location="cpu") clean_sd = {} @@ -300,7 +420,7 @@ class ESMDBPEmbedder: chunks.append(chunk) return chunks - def embed(self, seqs): + def embed(self, seqs, batch_size=1): all_embeddings = [] for i, seq in enumerate(seqs): chunks = self._chunk_sequence(seq) @@ -325,6 +445,7 @@ class ESMDBPEmbedder: all_embeddings.append(seq_vec.cpu().numpy()) return np.vstack(all_embeddings) + class GPNEmbedder: def __init__(self, device): model_name = "songlab/gpn-msa-sapiens" @@ -334,18 +455,16 @@ class GPNEmbedder: self.model.eval() 
self.device = device - def embed(self, seqs): + def embed(self, seqs, batch_size=1): inputs = self.tokenizer( - seqs, - return_tensors="pt", - padding=True, - truncation=True + seqs, return_tensors="pt", padding=True, truncation=True ).to(self.device) with torch.no_grad(): last_hidden = self.model(**inputs).last_hidden_state return last_hidden.mean(dim=1).cpu().numpy() + class ProGenEmbedder: def __init__(self, device): model_name = "jinyuan22/ProGen2-base" @@ -353,13 +472,10 @@ class ProGenEmbedder: self.model = AutoModel.from_pretrained(model_name).to(device).eval() self.device = device - def embed(self, seqs): + def embed(self, seqs, batch_size=1): inputs = self.tokenizer( - seqs, - return_tensors="pt", - padding=True, - truncation=True + seqs, return_tensors="pt", padding=True, truncation=True ).to(self.device) with torch.no_grad(): last_hidden = self.model(**inputs).last_hidden_state - return last_hidden.mean(dim=1).cpu().numpy() \ No newline at end of file + return last_hidden.mean(dim=1).cpu().numpy() diff --git a/dpacman/data_tasks/embeddings/protein.py b/dpacman/data_tasks/embeddings/protein.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..eaea5e3dfb94d0eed3be0a4a97376038712105d1 100644 --- a/dpacman/data_tasks/embeddings/protein.py +++ b/dpacman/data_tasks/embeddings/protein.py @@ -0,0 +1,77 @@ +from .utils import embed_and_save +from dpacman.data_tasks.embeddings import get_embedder + +import os +import torch +import json +import pandas as pd +from pathlib import Path +from omegaconf import DictConfig +import rootutils +from dpacman.utils import pylogger + +root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +logger = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def main(cfg: DictConfig): + logger.info( + f"Making embeddings using {cfg.data_task.prot_model} for protein sequences at {cfg.data_task.input_file}" + ) + # make out dir if necessary + out_dir = Path(root) / cfg.data_task.out_dir + os.makedirs(out_dir, exist_ok=True) + + # set device + device = "cpu" + if cfg.data_task.device == "gpu": + if torch.cuda.is_available(): + device = "cuda" + logger.info(f"Using device: {device}") + + # read the input file + input_file = Path(root) / cfg.data_task.input_file + if str(input_file).endswith(".json"): + # load the json and isolate the sequences and ids + with open(input_file, "r") as f: + d = json.load(f) + + df = pd.DataFrame.from_dict(d, orient="index").reset_index() + df.columns = ["seq_id", "sequence"] + + if cfg.data_task.debug: + logger.info(f"DEBUG MODE. 
Only embedding 5 sequences") + df = df.sample(n=5, random_state=42).reset_index(drop=True) + + # crucial: sort by SIZE so that we are padding things in a way that will support pre-batching + df["sequence_length"] = df["sequence"].str.len() + df = df.sort_values(by="sequence_length", ascending=True).reset_index(drop=True) + + # turn into list of sequences and IDs + tr_seqs = df["sequence"].tolist() + tr_ids = df["seq_id"].tolist() + logger.info( + f"Embedding {len(tr_ids)} transcriptional regulators (TRs) from processed remap data" + ) + + # Get the DNA embedder + prot_embedder = get_embedder(cfg.data_task.prot_model, device, for_dna=False) + logger.info(f"Device of embedding model: {prot_embedder.device}") + out_trs = str(out_dir / f"trs_{cfg.data_task.prot_model}.pkl") + if cfg.data_task.debug: + out_trs = out_trs.replace(".pkl", "_debug.pkl") + + embed_and_save( + tr_seqs, + tr_ids, + prot_embedder, + out_trs, + batch_size=cfg.data_task.batch_size, + save_as_shelf=cfg.data_task.save_as_shelf, + ) + + logger.info(f"Finished embedding protein sequences. Saved to: {out_trs}") + + +if __name__ == "__main__": + main() diff --git a/dpacman/data_tasks/embeddings/utils.py b/dpacman/data_tasks/embeddings/utils.py index 417022aba8438351bbaef802b96fd63f80e00f16..38d912713e364fc27c6320428263cfe12536dd56 100644 --- a/dpacman/data_tasks/embeddings/utils.py +++ b/dpacman/data_tasks/embeddings/utils.py @@ -1,7 +1,27 @@ """ Utility funcitons related to creating embeddings """ + import numpy as np +from pathlib import Path +import pickle +import shelve +import rootutils +from dpacman.utils import pylogger + +root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +logger = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def pkl_to_shelf(pkl_path: str, shelf_path: str): + # WARNING: this will load the original pickle once. + with open(pkl_path, "rb") as f: + data = pickle.load(f) # {sequence_str: np.ndarray or list} + with shelve.open(shelf_path, flag="n", writeback=False) as db: + for k, v in data.items(): + arr = np.asarray(v) # ensure ndarray + db[str(k)] = arr + def pad_token_embeddings(list_of_arrays, pad_value=0.0): """ @@ -21,27 +41,93 @@ def pad_token_embeddings(list_of_arrays, pad_value=0.0): mask[i, :L] = True return padded, mask -def embed_and_save(seqs, ids, embedder, out_path): + +def _to_numpy(x): + """Best-effort: convert torch.Tensor or arraylikes to np.ndarray (CPU).""" + try: + import torch + + if isinstance(x, torch.Tensor): + return x.detach().cpu().numpy() + except Exception: + pass + return np.asarray(x) + + +def embed_and_save( + seqs, ids, embedder, out_path, batch_size=1, save_as_shelf: bool = True +): """ - Using the passed embedder, make embeddings + Using the passed embedder, make embeddings and store as a pickle mapping: + {sequence (str): embedding (np.ndarray)} + + Notes: + - If multiple entries share the exact same sequence string, the *last one wins* in the embedder. + - Validates that every requested sequence has an embedding. 
""" - embs = embedder.embed(seqs) - - # Decide whether we got variable-length per-token outputs (list of (L, D)) - is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and embs[0].ndim == 2 - - if is_variable_token: - # pad to (N, L_max, D) + mask - padded, mask = pad_token_embeddings(embs) - # Save both embeddings and mask together in an .npz for convenience - np.savez_compressed(out_path.with_suffix(".caduceus.npz"), - embeddings=padded, - mask=mask, - ids=np.array(ids, dtype=object), - seqs=np.array(seqs, dtype=object)) - else: - # fixed shape output, e.g., pooled (N, D) - array = np.vstack(embs) if isinstance(embs, list) else embs - np.save(out_path, array) - with open(out_path.with_suffix(".ids"), "w") as f: - f.write("\n".join(ids)) + out_path = Path(out_path) + pkl_path = out_path.with_suffix(".pkl") + + # 1) Run the embedder (expects dict: {seq: embedding}) + embs_dict = embedder.embed(seqs, batch_size=batch_size) + if not isinstance(embs_dict, dict): + raise TypeError(f"Expected dict from embedder.embed, got {type(embs_dict)}") + + # 2) Detect duplicates in the input order + seen, dupes = set(), 0 + for s in seqs: + if s in seen: + dupes += 1 + seen.add(s) + if dupes: + msg = ( + f"[embed_and_save] Warning: {dupes} duplicate sequence(s) in input; " + f"pickle will contain one entry per unique sequence." + ) + try: + (logger.info if logger is not None else print)(msg) + except Exception: + print(msg, flush=True) + + # 3) Build ordered mapping (respect input order; last occurrence already reflected in embs_dict) + mapping = {} + missing = [] + for s in seqs: + if s in mapping: + continue # already stored (keep one per unique sequence) + e = embs_dict.get(s) + if e is None: + missing.append(s) + continue + mapping[s] = _to_numpy(e) + + if missing: + raise KeyError( + f"Embedder did not return embeddings for {len(missing)} sequence(s). " + f"Example: {missing[0][:50]}..." 
+ ) + + # 4) Save pickle + with open(pkl_path, "wb") as f: + pickle.dump(mapping, f, protocol=5) + logger.info(f"Saved as pkl at {pkl_path}") + + # 5) Optional tiny manifest + try: + n = len(mapping) + ndims = [v.ndim for v in mapping.values()] + n_vec = sum(d == 1 for d in ndims) # pooled (D,) + n_tok = sum(d == 2 for d in ndims) # per-token (L, D) + n_other = n - n_vec - n_tok + with open(out_path.with_suffix(".pkl.meta"), "w") as mf: + mf.write( + f"entries={n}\npooled_1d={n_vec}\nper_token_2d={n_tok}\nother={n_other}\n" + ) + except Exception: + pass + + if save_as_shelf: + shelf_path = str(pkl_path).replace(".pkl", ".shelf") + pkl_to_shelf(pkl_path=pkl_path, shelf_path=shelf_path) + logger.info(f"Saved as shelf at {shelf_path}") + return pkl_path diff --git a/dpacman/data_tasks/fimo/post_fimo.py b/dpacman/data_tasks/fimo/post_fimo.py index 5ced4791a308f9a510a1a8371c5296af2bc63623..427165953e8d7f1023f234051bbb410d8187070e 100644 --- a/dpacman/data_tasks/fimo/post_fimo.py +++ b/dpacman/data_tasks/fimo/post_fimo.py @@ -1,42 +1,45 @@ #!/usr/bin/env python3 import os -import uuid -import logging from pathlib import Path import multiprocessing as mp - import numpy as np import pandas as pd -import math import json import rootutils import polars as pl from omegaconf import DictConfig from hydra.core.hydra_config import HydraConfig - +import logging from dpacman.data_tasks.fimo.pre_fimo import load_chrom_dna +import rootutils +from dpacman.utils import pylogger root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -logger = logging.getLogger(__name__) +logger = pylogger.RankedLogger(__name__, rank_zero_only=True) + -def normalize_array(arr: np.ndarray, max_chipseq_score: int=1000, jaspar_boost:int=100) -> np.ndarray: +def normalize_array( + arr: np.ndarray, max_chipseq_score: int = 1000, jaspar_boost: int = 100 +) -> np.ndarray: normalization_factor = max_chipseq_score + jaspar_boost return arr / normalization_factor + def format_sig(sig_vals, decimals=4, atol=0.0, rtol=1e-5): a = np.asarray(sig_vals, dtype=float) - scale = 10.0 ** decimals - thresh = 0.5 / scale # 0.00005 for 4 dp + scale = 10.0**decimals + thresh = 0.5 / scale # 0.00005 for 4 dp # Would display as 0.0000 or 1.0000 at given precision? 
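+    # Worked example (hedged, illustrative values): with decimals=4,
+    # thresh = 0.5 / 10**4 = 5e-5, so format_sig([0.00003, 0.99997, 0.5]) -> "0,1,0.5000".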
m0 = np.isclose(a, 0.0, atol=atol, rtol=rtol) | (np.abs(a) <= thresh) m1 = np.isclose(a, 1.0, atol=atol, rtol=rtol) | (np.abs(a - 1.0) <= thresh) - out = np.char.mod(f'%.{decimals}f', a) - out = np.where(m0, '0', out) - out = np.where(m1 & ~m0, '1', out) # don’t overwrite any zeros + out = np.char.mod(f"%.{decimals}f", a) + out = np.where(m0, "0", out) + out = np.where(m1 & ~m0, "1", out) # don’t overwrite any zeros return ",".join(out.tolist()) + def _safe_process(task): """ Returns: @@ -44,26 +47,30 @@ def _safe_process(task): ("err", (chrom, msg, traceback)) on failure """ import traceback as tb + chrom = task[0] try: - out_path = _process_one_chrom_folder(task) # MUST return a path (str/Path) + out_path = _process_one_chrom_folder(task) # MUST return a path (str/Path) return ("ok", str(out_path)) except Exception as e: return ("err", (chrom, repr(e), tb.format_exc())) + def discover_chrom_folders(fimo_out_dir: Path) -> list[str]: return sorted( - name for name in os.listdir(fimo_out_dir) + name + for name in os.listdir(fimo_out_dir) if name.startswith("chrom") and (fimo_out_dir / name / "final.csv").exists() ) + def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict: # row order: TR, chrom, cstart, cend, peak_s, peak_e, chipscore, jaspar trname, chrom, cstart, cend, peak_s, peak_e, chipscore, jaspar = row - + # very few chipscores are > 1000. standardize by setting >1000 to a max score - if chipscore>=1000: - chipscore=1000 + if chipscore >= 1000: + chipscore = 1000 seq = dna[cstart:cend] L = len(seq) @@ -74,15 +81,15 @@ def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict: pe = peak_e - cstart peak_seq = "" if ps < L and pe > 0: - scores[max(ps, 0):min(pe, L)] = chipscore - peak_seq = seq[max(ps, 0):min(pe, L)] - + scores[max(ps, 0) : min(pe, L)] = chipscore + peak_seq = seq[max(ps, 0) : min(pe, L)] + # JASPAR hits (+jaspar_boost) # only run if the peak is not np.nan total_jaspar = 0 if isinstance(jaspar, str) and jaspar.strip(): for hit in jaspar.split(","): - total_jaspar+=1 + total_jaspar += 1 hs, he = hit.split("-") hs_i = max(int(hs) - cstart, 0) he_i = min(int(he) - cstart, L) @@ -90,9 +97,9 @@ def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict: scores[hs_i:he_i] = chipscore + jaspar_boost score_str = ",".join(map(str, [int(x) for x in scores.tolist()])) - #sig_vals = normalize_array(scores.astype(np.float32)) + # sig_vals = normalize_array(scores.astype(np.float32)) # store out to 4 decimal places unless it's 0 - #score_sig = format_sig(sig_vals) + # score_sig = format_sig(sig_vals) return { "chrom": chrom, "tr_name": trname, @@ -103,14 +110,22 @@ def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict: "scores": score_str, } + def _process_one_chrom_folder(task) -> pd.DataFrame: """Runs inside a worker process. 
Reads one chrom’s final.csv, loads DNA once, builds records.""" - chrom_folder, fimo_out_dir_str, json_dir, jaspar_boost, output_parts_folder, keep_fimo_only = task - + ( + chrom_folder, + fimo_out_dir_str, + json_dir, + jaspar_boost, + output_parts_folder, + keep_fimo_only, + ) = task + # make unique logger for this process log_dir = Path(HydraConfig.get().run.dir) / "logs" log_dir.mkdir(parents=True, exist_ok=True) - output_parts_folder.mkdir(parents=True,exist_ok=True) + output_parts_folder.mkdir(parents=True, exist_ok=True) log_file = log_dir / f"fimo_{chrom_folder}.log" wlogger = logging.getLogger(f"fimo_{chrom_folder}") @@ -121,24 +136,30 @@ def _process_one_chrom_folder(task) -> pd.DataFrame: fh = logging.FileHandler(log_file, mode="w", encoding="utf-8") fh.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) wlogger.addHandler(fh) - + fimo_out_dir = Path(fimo_out_dir_str) final_csv = fimo_out_dir / chrom_folder / "final.csv" if not final_csv.exists(): return pd.DataFrame() - usecols = ["TR", "#chrom", "contextStart", "contextEnd", - "ChIPStart", "ChIPEnd", "chipscore", "jaspar"] + usecols = [ + "TR", + "#chrom", + "contextStart", + "contextEnd", + "ChIPStart", + "ChIPEnd", + "chipscore", + "jaspar", + ] df = pd.read_csv(final_csv, usecols=usecols) if df.empty: return pd.DataFrame() - + if keep_fimo_only: logger.info(f"keep_fimo_only=True. Starting with {len(df)} rows.") - df = df.loc[ - ~df["jaspar"].isna() - ].reset_index(drop=True) + df = df.loc[~df["jaspar"].isna()].reset_index(drop=True) logger.info(f"After keeping fimo hits only: {len(df)} rows remain.") # Normalize dtypes up-front @@ -148,11 +169,13 @@ def _process_one_chrom_folder(task) -> pd.DataFrame: chrom = df["#chrom"].iloc[0] dna_cache = {} - dna = load_chrom_dna(str(chrom), dna_cache, json_dir).upper() # just capitalize it for training + dna = load_chrom_dna( + str(chrom), dna_cache, json_dir + ).upper() # just capitalize it for training wlogger.info(f"Loaded DNA for {chrom}, length {len(dna)}") records = [] - + # rename to make processing easier rename = { "#chrom": "chrom", @@ -163,7 +186,7 @@ def _process_one_chrom_folder(task) -> pd.DataFrame: "TR": "tr_name", } df = df.rename(columns=rename) - + # (Optional) ensure numeric dtypes; will raise if non-numeric for col in ["cstart", "cend", "peak_s", "peak_e", "chipscore"]: df[col] = pd.to_numeric(df[col], errors="raise") @@ -174,9 +197,18 @@ def _process_one_chrom_folder(task) -> pd.DataFrame: for i, row in enumerate(df.itertuples(index=False), start=1): records.append( _process_one_row( - (row.tr_name, row.chrom, int(row.cstart), int(row.cend), - int(row.peak_s), int(row.peak_e), int(row.chipscore), row.jaspar), - dna, jaspar_boost + ( + row.tr_name, + row.chrom, + int(row.cstart), + int(row.cend), + int(row.peak_s), + int(row.peak_e), + int(row.chipscore), + row.jaspar, + ), + dna, + jaspar_boost, ) ) @@ -185,18 +217,27 @@ def _process_one_chrom_folder(task) -> pd.DataFrame: if decile > last_decile: last_decile = decile wlogger.info("Progress: %d%% (%d/%d)", decile * 10, i, total) - + wlogger.info(f"Completed processing {len(records)} rows for {chrom_folder}") - + # make into a DataFrame and save records_df = pd.DataFrame.from_records(records) savepath = output_parts_folder / f"{chrom_folder}_processed.csv" records_df.to_csv(savepath, index=False) wlogger.info(f"Saved records to {savepath}") - + return savepath -def build_dataset_fast_mp(fimo_out_dir: Path, json_dir: str, debug: bool, max_workers: int | None, jaspar_boost: int = 100, 
output_parts_folder: str = None, keep_fimo_only: bool=False) -> pd.DataFrame: + +def build_dataset_fast_mp( + fimo_out_dir: Path, + json_dir: str, + debug: bool, + max_workers: int | None, + jaspar_boost: int = 100, + output_parts_folder: str = None, + keep_fimo_only: bool = False, +) -> pd.DataFrame: """ Multiprocessing to build final dataset across chromosomes """ @@ -210,8 +251,17 @@ def build_dataset_fast_mp(fimo_out_dir: Path, json_dir: str, debug: bool, max_wo chrom_folders = [c for c in chrom_folders if c == "chromY"] or chrom_folders[:1] logger.info(f"DEBUG MODE: considering {chrom_folders[0]} only") - tasks = [(cf, str(fimo_out_dir), json_dir, jaspar_boost, output_parts_folder, keep_fimo_only) - for cf in chrom_folders] + tasks = [ + ( + cf, + str(fimo_out_dir), + json_dir, + jaspar_boost, + output_parts_folder, + keep_fimo_only, + ) + for cf in chrom_folders + ] def _collect(status, payload, good_paths, errs): if status == "ok": @@ -239,7 +289,9 @@ def build_dataset_fast_mp(fimo_out_dir: Path, json_dir: str, debug: bool, max_wo good_paths: list[Path] = [] errs: list[tuple[str, str, str]] = [] with mp.Pool(processes=procs, maxtasksperchild=10) as pool: - for status, payload in pool.imap_unordered(_safe_process, tasks, chunksize=1): + for status, payload in pool.imap_unordered( + _safe_process, tasks, chunksize=1 + ): _collect(status, payload, good_paths, errs) if errs: @@ -249,29 +301,35 @@ def build_dataset_fast_mp(fimo_out_dir: Path, json_dir: str, debug: bool, max_wo return [str(p) for p in good_paths] -def dedup_trname_peakseq_weighted(lf: pl.LazyFrame, seed: int = 42, outdir: str | None = None) -> pl.LazyFrame: + +def dedup_trname_peakseq_weighted( + lf: pl.LazyFrame, seed: int = 42, outdir: str | None = None +) -> pl.LazyFrame: """ Remove duplicate pairings of TR + peak sequence, but keep the distribution of chromosomes as best as possible. 
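+    Hedged note on the mechanism used below: each candidate row is scored with
+    log(chromosome weight) + seeded Gumbel noise and the highest score per
+    (tr_name, peak_sequence) group is kept (the Gumbel-max trick), so the kept
+    rows roughly follow the pre-dedup chromosome ratios.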
- Use a seed so the results are reproducible + Use a seed so the results are reproducible """ # Normalize key dtypes - lf = lf.with_columns([ - pl.col("chrom").cast(pl.Utf8), - pl.col("tr_name").cast(pl.Utf8), - pl.col("peak_sequence").fill_null("").cast(pl.Utf8), - ]) + lf = lf.with_columns( + [ + pl.col("chrom").cast(pl.Utf8), + pl.col("tr_name").cast(pl.Utf8), + pl.col("peak_sequence").fill_null("").cast(pl.Utf8), + ] + ) # --- BEFORE: counts/ratios (materialize tiny table) pre_df = ( - lf.group_by("chrom").len() + lf.group_by("chrom") + .len() .with_columns((pl.col("len") / pl.col("len").sum()).alias("pre_ratio")) .sort("chrom") .collect() ) # Expected #groups (must equal result rows) - exp_groups = lf.select(["tr_name","peak_sequence"]).unique().collect().height - logger.info(f"Expected groups: {exp_groups}") + exp_groups = lf.select(["tr_name", "peak_sequence"]).unique().collect().height + logger.info(f"Expected groups: {exp_groups}") # Tiny weights table back to lazy pre_lf = pre_df.lazy().select(["chrom", "pre_ratio"]).rename({"pre_ratio": "w"}) @@ -280,52 +338,96 @@ def dedup_trname_peakseq_weighted(lf: pl.LazyFrame, seed: int = 42, outdir: str TWO64 = 18446744073709551616.0 eps = 1e-12 - h_expr = pl.concat_str( - [pl.lit(f"seed:{seed}"), pl.col("tr_name"), pl.col("peak_sequence"), pl.col("chrom")], - separator="|", - ).hash().cast(pl.UInt64) + h_expr = ( + pl.concat_str( + [ + pl.lit(f"seed:{seed}"), + pl.col("tr_name"), + pl.col("peak_sequence"), + pl.col("chrom"), + ], + separator="|", + ) + .hash() + .cast(pl.UInt64) + ) u_expr = (h_expr.cast(pl.Float64) + 1.0) / pl.lit(TWO64) - u_expr = pl.when(u_expr < eps).then(eps).when(u_expr > 1 - eps).then(1 - eps).otherwise(u_expr) + u_expr = ( + pl.when(u_expr < eps) + .then(eps) + .when(u_expr > 1 - eps) + .then(1 - eps) + .otherwise(u_expr) + ) - logw_expr = pl.when(pl.col("w").is_null() | (pl.col("w") <= 0)).then(eps).otherwise(pl.col("w")).log() + logw_expr = ( + pl.when(pl.col("w").is_null() | (pl.col("w") <= 0)) + .then(eps) + .otherwise(pl.col("w")) + .log() + ) gumbel_expr = -(-u_expr.log()).log() score_expr = (logw_expr + gumbel_expr).alias("_score") - hash_expr = h_expr.alias("_h") + hash_expr = h_expr.alias("_h") # Attach weights & scores, globally sort, then unique on the keys (keep first) lf_sorted = ( lf.join(pre_lf, on="chrom", how="left") - .with_columns([score_expr, hash_expr]) - .sort(["_score","_h"], descending=[True, False]) + .with_columns([score_expr, hash_expr]) + .sort(["_score", "_h"], descending=[True, False]) ) - lf_sel = lf_sorted.unique(subset=["tr_name","peak_sequence"], keep="first") \ - .drop(["w","_score","_h"]) + lf_sel = lf_sorted.unique(subset=["tr_name", "peak_sequence"], keep="first").drop( + ["w", "_score", "_h"] + ) # --- AFTER: counts/ratios + save post_df = ( - lf_sel.group_by("chrom").len() - .with_columns((pl.col("len") / pl.col("len").sum()).alias("post_ratio")) - .sort("chrom") - .collect(streaming=True) - ) - compare_df = (pre_df.select(["chrom","len","pre_ratio"]).rename({"len":"pre_n"}) - .join(post_df.select(["chrom","len","post_ratio"]).rename({"len":"post_n"}), - on="chrom", how="full") - .fill_null(0) - .with_columns((pl.col("post_ratio") - pl.col("pre_ratio")).abs().alias("abs_delta"), - (100*(pl.col("post_ratio") - pl.col("pre_ratio"))/pl.col("pre_ratio")).abs().alias("pcnt_delta")) - .sort("chrom") - ).to_pandas().drop(columns=["chrom_right"]) + lf_sel.group_by("chrom") + .len() + .with_columns((pl.col("len") / pl.col("len").sum()).alias("post_ratio")) + .sort("chrom") + 
.collect(streaming=True) + ) + compare_df = ( + ( + pre_df.select(["chrom", "len", "pre_ratio"]) + .rename({"len": "pre_n"}) + .join( + post_df.select(["chrom", "len", "post_ratio"]).rename( + {"len": "post_n"} + ), + on="chrom", + how="full", + ) + .fill_null(0) + .with_columns( + (pl.col("post_ratio") - pl.col("pre_ratio")).abs().alias("abs_delta"), + ( + 100 + * (pl.col("post_ratio") - pl.col("pre_ratio")) + / pl.col("pre_ratio") + ) + .abs() + .alias("pcnt_delta"), + ) + .sort("chrom") + ) + .to_pandas() + .drop(columns=["chrom_right"]) + ) # --- Sanity: must keep exactly one per group got_rows = lf_sel.select(pl.len()).collect()["len"][0] if got_rows != exp_groups: # optional: raise or just log - logger.warning(f"Dedup cardinality mismatch: expected {exp_groups}, got {got_rows}") + logger.warning( + f"Dedup cardinality mismatch: expected {exp_groups}, got {got_rows}" + ) return lf_sel, compare_df + def write_map(lf: pl.LazyFrame, out_path: str, key: str, val: str, outname: str): """ Write the ID maps we created, spanning all the data. Will be called for: @@ -336,23 +438,19 @@ def write_map(lf: pl.LazyFrame, out_path: str, key: str, val: str, outname: str) maps_dir = Path(out_path).parent / "maps" maps_dir.mkdir(parents=True, exist_ok=True) - df = ( - lf.select([pl.col(key), pl.col(val)]) - .unique() - .collect(streaming=True) - ) + df = lf.select([pl.col(key), pl.col(val)]).unique().collect(streaming=True) mapping = dict(zip(df[key].to_list(), df[val].to_list())) with open(maps_dir / outname, "w") as f: json.dump(mapping, f, indent=2) - + def combine_processed_with_polars( paths_to_processed_dfs: list[str], - idmap_path: str, # TSV with columns: From, Entry, Sequence - out_path: str, # e.g., "processed_out.parquet" or ".csv" + idmap_path: str, # TSV with columns: From, Entry, Sequence + out_path: str, # e.g., "processed_out.parquet" or ".csv" max_protein_len: int = None, check_violations: bool = False, - seeds: list = [0] + seeds: list = [0], ): if not paths_to_processed_dfs: logger.info("No records produced; nothing to write.") @@ -361,76 +459,80 @@ def combine_processed_with_polars( # 1) Scan each CSV and normalize dtypes BEFORE concat lfs = [] for p in paths_to_processed_dfs: - lf_i = ( - pl.scan_csv(p) # don't use infer_schema_length=0 here - .with_columns([ + lf_i = pl.scan_csv(p).with_columns( # don't use infer_schema_length=0 here + [ pl.col("chrom").cast(pl.Utf8), pl.col("tr_name").cast(pl.Utf8), pl.col("dna_sequence").cast(pl.Utf8), pl.col("peak_sequence").cast(pl.Utf8), pl.col("scores").cast(pl.Utf8), - pl.col("chipscore").cast(pl.Float64), # robust (int -> float OK) + pl.col("chipscore").cast(pl.Float64), # robust (int -> float OK) pl.col("total_jaspar_hits").cast(pl.Int64), - ]) + ] ) lfs.append(lf_i) # 2) Now concat; schemas match lf_og = pl.concat(lfs, how="vertical") - + # Read idmap to get list of unmapped TRs, those with no sequence - idmap = ( - pl.read_csv(idmap_path, separator="\t", columns=["From", "Entry", "Sequence"]) - .rename({"From": "tr_name", "Entry": "tr_uniprot", "Sequence": "tr_sequence"}) - ) + idmap = pl.read_csv( + idmap_path, separator="\t", columns=["From", "Entry", "Sequence"] + ).rename({"From": "tr_name", "Entry": "tr_uniprot", "Sequence": "tr_sequence"}) idmap = idmap.with_columns( - pl.col("tr_sequence").map_elements(lambda x: len(x), return_dtype=pl.Int64).alias("tr_len") + pl.col("tr_sequence") + .map_elements(lambda x: len(x), return_dtype=pl.Int64) + .alias("tr_len") ) if max_protein_len is not None: - idmap = idmap.filter( - 
pl.col("tr_len")<=max_protein_len - ) + idmap = idmap.filter(pl.col("tr_len") <= max_protein_len) logger.info(f"Filtered valid TRs to only those with len <= {max_protein_len}") success_trs = list(idmap["tr_name"].unique()) logger.info(f"Total valid TRs: {len(success_trs)}") - + # Filter lf to only have "success-TRs" lf_og = lf_og.filter(pl.col("tr_name").is_in(success_trs)) - + # 2) We COULD drop duplicate occurrences of tr_name and peak_sequence, because these are the same peak. But here we're showing them in different contexts - # Instead, let's drop duplicate occurrences of the same tr_name and dna_sequence, because these are duplicate datapoints. - lf_out = None # the last one will be used to save the example file - out_path = str(out_path) + # Instead, let's drop duplicate occurrences of the same tr_name and dna_sequence, because these are duplicate datapoints. + lf_out = None # the last one will be used to save the example file + out_path = str(out_path) Path(out_path).parent.mkdir(parents=True, exist_ok=True) - + lf_og = lf_og.join(idmap.lazy(), on="tr_name", how="left") logger.info(f"Merged in UniProt IDs and TR sequences from UniProt ID mappping") - + # 4) Per-chromosome unique peak index and peak_id # (dense rank over peak_sequence per chrom; if you require "first-appearance" order, # see the note below for an alternate approach.) # Ensure types - lf_og = lf_og.with_columns([ - pl.col("dna_sequence").cast(pl.Utf8), - pl.col("tr_sequence").cast(pl.Utf8), - pl.col("peak_sequence").cast(pl.Utf8) - ]) - - # set chrom+ peak sequence IDs - lf_og = lf_og.with_columns([ - pl.col("peak_sequence").fill_null("").alias("peak_sequence"), - pl.col("chrom").cast(pl.Utf8), - ]) + lf_og = lf_og.with_columns( + [ + pl.col("dna_sequence").cast(pl.Utf8), + pl.col("tr_sequence").cast(pl.Utf8), + pl.col("peak_sequence").cast(pl.Utf8), + ] + ) + + # set chrom+ peak sequence IDs + lf_og = lf_og.with_columns( + [ + pl.col("peak_sequence").fill_null("").alias("peak_sequence"), + pl.col("chrom").cast(pl.Utf8), + ] + ) lf_og = lf_og.with_columns( pl.col("peak_sequence") - .rank(method="dense") # 1,2,3,... per group + .rank(method="dense") # 1,2,3,... 
per group .over("chrom") .cast(pl.Int64) .alias("chrom_peak_idx") ) lf_og = lf_og.with_columns( - pl.format("chr{}_peak{}", pl.col("chrom"), pl.col("chrom_peak_idx")).alias("chrpeak_id") + pl.format("chr{}_peak{}", pl.col("chrom"), pl.col("chrom_peak_idx")).alias( + "chrpeak_id" + ) ) logger.info(f"Assigned unique chrpeak_ids per chromosome based on peak_sequence") @@ -438,90 +540,154 @@ def combine_processed_with_polars( # (do this by creating small maps with unique(..., maintain_order=True) and joining) # Sequence-based IDs without any joins lf_og = ( - lf_og.with_columns([ - pl.col("dna_sequence").rank(method="dense").cast(pl.Int64).alias("dna_idx"), - pl.col("tr_sequence").rank(method="dense").cast(pl.Int64).alias("tr_idx"), - pl.col("peak_sequence").rank(method="dense").cast(pl.Int64).alias("peak_idx"), - ]) - .with_columns([ - pl.format("dnaseq{}", pl.col("dna_idx")).alias("dna_seqid"), - pl.format("trseq{}", pl.col("tr_idx")).alias("tr_seqid"), - pl.format("peakseq{}", pl.col("peak_idx")).alias("peak_seqid"), - ]) + lf_og.with_columns( + [ + pl.col("dna_sequence") + .rank(method="dense") + .cast(pl.Int64) + .alias("dna_idx"), + pl.col("tr_sequence") + .rank(method="dense") + .cast(pl.Int64) + .alias("tr_idx"), + pl.col("peak_sequence") + .rank(method="dense") + .cast(pl.Int64) + .alias("peak_idx"), + ] + ) + .with_columns( + [ + pl.format("dnaseq{}", pl.col("dna_idx")).alias("dna_seqid"), + pl.format("trseq{}", pl.col("tr_idx")).alias("tr_seqid"), + pl.format("peakseq{}", pl.col("peak_idx")).alias("peak_seqid"), + ] + ) .drop(["dna_idx", "tr_idx", "peak_idx"]) ) logger.info(f"Assigned unique dna IDs, transcriptional regulator IDs, and peak IDs") # Final ID (will never be None now) lf_og = lf_og.with_columns( - pl.concat_str([pl.col("tr_seqid"), pl.lit("_"), pl.col("dna_seqid")], ignore_nulls=False) - .alias("ID") + pl.concat_str( + [pl.col("tr_seqid"), pl.lit("_"), pl.col("dna_seqid")], ignore_nulls=False + ).alias("ID") ) - - # Write the maps + + # Write the maps # call it for each mapping - write_map(lf_og, out_path=out_path, val="tr_sequence", key="tr_seqid", outname="tr_seqid_to_tr_sequence.json") - write_map(lf_og, out_path=out_path, val="peak_sequence", key="peak_seqid", outname="peak_seqid_to_peak_sequence.json") - write_map(lf_og, out_path=out_path, val="dna_sequence", key="dna_seqid", outname="dna_seqid_to_dna_sequence.json") - + write_map( + lf_og, + out_path=out_path, + val="tr_sequence", + key="tr_seqid", + outname="tr_seqid_to_tr_sequence.json", + ) + write_map( + lf_og, + out_path=out_path, + val="peak_sequence", + key="peak_seqid", + outname="peak_seqid_to_peak_sequence.json", + ) + write_map( + lf_og, + out_path=out_path, + val="dna_sequence", + key="dna_seqid", + outname="dna_seqid_to_dna_sequence.json", + ) + for seed in seeds: - # edit out path to include seed + # edit out path to include seed if "." in out_path: - out_path_full = out_path[0:out_path.rindex(".")] + f"_seed{seed}" + out_path[out_path.rindex("."):] - else: + out_path_full = ( + out_path[0 : out_path.rindex(".")] + + f"_seed{seed}" + + out_path[out_path.rindex(".") :] + ) + else: out_path_full = out_path + f"_seed{seed}.parquet" - + lf, compare_df = dedup_trname_peakseq_weighted(lf_og, seed=seed) - logger.info(f"Dropped duplicate examples of tr_name + peak_sequence. Maintained chrom distribution with weighted random sampling (seed={seed}).") - + logger.info( + f"Dropped duplicate examples of tr_name + peak_sequence. Maintained chrom distribution with weighted random sampling (seed={seed})." 
+ ) + # Save comparison df. Annotate with debug if it's a debug run - compare_df_path = str(Path(out_path).parent/f"chrom_ratio_compare_seed{seed}.csv") - if "debug" in out_path: compare_df_path = compare_df_path.replace(".csv", "_debug.csv") + compare_df_path = str( + Path(out_path).parent / f"chrom_ratio_compare_seed{seed}.csv" + ) + if "debug" in out_path: + compare_df_path = compare_df_path.replace(".csv", "_debug.csv") compare_df.to_csv(compare_df_path, index=False) - + # 3) Join small idmap (read eagerly; it’s tiny) lf = lf.join(idmap.lazy(), on="tr_name", how="left") logger.info(f"Merged in UniProt IDs and TR sequences from UniProt ID mappping") - logger.info(f"Applied dna_sequence and tr_sequence IDs to main table") - + # Each sequence maps to exactly one id viol1 = ( - lf.select("dna_sequence", "dna_seqid").unique() - .group_by("dna_sequence").agg(pl.n_unique("dna_seqid").alias("n_ids")) + lf.select("dna_sequence", "dna_seqid") + .unique() + .group_by("dna_sequence") + .agg(pl.n_unique("dna_seqid").alias("n_ids")) .filter(pl.col("n_ids") > 1) .collect() ) # Each id maps to exactly one sequence viol2 = ( - lf.select("dna_sequence", "dna_seqid").unique() - .group_by("dna_seqid").agg(pl.n_unique("dna_sequence").alias("n_seqs")) + lf.select("dna_sequence", "dna_seqid") + .unique() + .group_by("dna_seqid") + .agg(pl.n_unique("dna_sequence").alias("n_seqs")) .filter(pl.col("n_seqs") > 1) .collect() ) - logger.info("viol1 rows (seq→>1 id): %d; viol2 rows (id→>1 seq): %d", viol1.height, viol2.height) + logger.info( + "viol1 rows (seq→>1 id): %d; viol2 rows (id→>1 seq): %d", + viol1.height, + viol2.height, + ) # No NULLs - nulls = lf.select([ - pl.col("dna_seqid").is_null().sum().alias("null_dna_seqid"), - pl.col("tr_seqid").is_null().sum().alias("null_tr_seqid"), - pl.col("ID").is_null().sum().alias("null_ID"), - ]).collect() + nulls = lf.select( + [ + pl.col("dna_seqid").is_null().sum().alias("null_dna_seqid"), + pl.col("tr_seqid").is_null().sum().alias("null_tr_seqid"), + pl.col("ID").is_null().sum().alias("null_ID"), + ] + ).collect() logger.info("NULL counts:\n%s", nulls) - + # 6) Final column selection cols = [ - "ID", "tr_seqid", "dna_seqid", "peak_seqid", "chrpeak_id", "tr_name", "chipscore", "total_jaspar_hits", - "dna_sequence", "tr_sequence", "scores" + "ID", + "tr_seqid", + "dna_seqid", + "peak_seqid", + "chrpeak_id", + "tr_name", + "chipscore", + "total_jaspar_hits", + "dna_sequence", + "tr_sequence", + "scores", ] lf_out = lf.select(cols) - #n_rows = lf_out.select(pl.len().alias("rows")).collect(streaming=True)["rows"][0] + # n_rows = lf_out.select(pl.len().alias("rows")).collect(streaming=True)["rows"][0] logger.info(f"Selected final columns") # 7) Write streaming to disk if out_path_full.lower().endswith(".parquet"): - lf_out.sink_parquet(out_path_full, compression="zstd", statistics=True, row_group_size=128_000) + lf_out.sink_parquet( + out_path_full, + compression="zstd", + statistics=True, + row_group_size=128_000, + ) logger.info(f"Wrote parquet file to {out_path_full}") elif out_path_full.lower().endswith(".csv"): # NOTE: collect(streaming=True) still returns an in-memory DataFrame; @@ -530,16 +696,23 @@ def combine_processed_with_polars( logger.info(f"Wrote csv file to {out_path_full}") else: # default to Parquet if no/unknown extension - lf_out.sink_parquet(out_path_full + ".parquet", compression="zstd", statistics=True) + lf_out.sink_parquet( + out_path_full + ".parquet", compression="zstd", statistics=True + ) logger.info(f"Wrote parquet file to {out_path_full}") # 
Save the FIRST 1000 rows to CSV (streaming-friendly) df_first = lf_out.limit(1000).collect(streaming=True) - example_out_path = Path(root) / "dpacman/data_files/processed/remap/examples" / "example1000_remap2022_crm_fimo_output_q_processed.csv" + example_out_path = ( + Path(root) + / "dpacman/data_files/processed/remap/examples" + / "example1000_remap2022_crm_fimo_output_q_processed.csv" + ) df_first.write_csv(example_out_path) logger.info(f"Wrote first 1000 rows to {example_out_path} as an example") -# FIMO check + +# FIMO check def get_reverse_complement(s): """ Returns 5' to 3' sequence of the reverse complement @@ -551,156 +724,184 @@ def get_reverse_complement(s): "c": "g", "t": "a", "g": "c", - "A":"T", + "A": "T", "C": "G", "T": "A", "G": "C", "n": "n", - "N": "N" + "N": "N", } for c in chars: recon += [rev_map[c]] - + recon = "".join(recon) return recon[::-1] + def extract_jaspar_motifs(row, reverse_complement=False): s = row["scores"] s = [int(x) for x in s.split(",")] n_motifs = row["total_jaspar_hits"] - if n_motifs==0: + if n_motifs == 0: return "" chipscore = row["chipscore"] dna_seq = row["dna_sequence"] if reverse_complement: dna_seq = row["dna_sequence_rc"] - jaspar_indices = [i for i in list(range(len(s))) if s[i]>chipscore] + jaspar_indices = [i for i in list(range(len(s))) if s[i] > chipscore] pred_motif = "" - for i in list(range(jaspar_indices[0], jaspar_indices[-1]+1)): - if not(i in jaspar_indices): + for i in list(range(jaspar_indices[0], jaspar_indices[-1] + 1)): + if not (i in jaspar_indices): pred_motif += "-" else: pred_motif += dna_seq[i] - + return pred_motif + def clean_idmap(idmap_path): """ - The raw ID Map from UniProt returned multiple results. + The raw ID Map from UniProt returned multiple results. We went to ReMap and wrote down what the right mappings are in these cases. 
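+    For example, manual_map below pins "BACH1" to UniProt O14867, and only idmap
+    rows whose Entry agrees with the manual choice are kept.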
""" - + manual_map = { - "BACH1": "O14867", - "BAP1": "Q92560", - "BDP1": "A6H8Y1", - "BRF1": "Q92994", - "CUX1": "Q13948", - "DDX21": "Q9NR30", - "ERG": "P11308", - "HBP1": "O60381", - "KLF14": "Q8TD94", - "MED1": "Q15648", - "MED25": "Q71SY5", - "MGA": "Q8IWI9", - "NRF1": "Q16656", - "PAF1": "Q8N7H5", - "PDX1": "P52945", - "RBP2": "P50120", - "RLF": "Q13129", - "SP1": "P08047", - "SPIN1": "Q9Y657", - "STAG1": "Q8WVM7", - "TAF15": "Q92804", - "TCF3": "P15923", + "BACH1": "O14867", + "BAP1": "Q92560", + "BDP1": "A6H8Y1", + "BRF1": "Q92994", + "CUX1": "Q13948", + "DDX21": "Q9NR30", + "ERG": "P11308", + "HBP1": "O60381", + "KLF14": "Q8TD94", + "MED1": "Q15648", + "MED25": "Q71SY5", + "MGA": "Q8IWI9", + "NRF1": "Q16656", + "PAF1": "Q8N7H5", + "PDX1": "P52945", + "RBP2": "P50120", + "RLF": "Q13129", + "SP1": "P08047", + "SPIN1": "Q9Y657", + "STAG1": "Q8WVM7", + "TAF15": "Q92804", + "TCF3": "P15923", "ZFP36": "P26651", "EVI1": "Q03112", - "MCM2": "P49736" + "MCM2": "P49736", } idmap = pd.read_csv(idmap_path, sep="\t") - idmap["Remap_Entry"] = idmap.apply(lambda row: row["Entry"] if not(row["From"] in manual_map) else manual_map[row["From"]], axis=1) - idmap_remapped = idmap.loc[ - idmap["Entry"]==idmap["Remap_Entry"] - ].reset_index(drop=True).drop(columns=["Remap_Entry"]) - - assert len(idmap_remapped)==len(idmap_remapped["From"].unique()) - logger.info(f"Total transcriptional regulators successfully mapped in UniProt: {len(idmap_remapped)}") - - clean_idmap_path = Path(root)/"dpacman/data_files/processed/remap/idmapping_reviewed_true_processed_2025_08_11.tsv" + idmap["Remap_Entry"] = idmap.apply( + lambda row: ( + row["Entry"] if not (row["From"] in manual_map) else manual_map[row["From"]] + ), + axis=1, + ) + idmap_remapped = ( + idmap.loc[idmap["Entry"] == idmap["Remap_Entry"]] + .reset_index(drop=True) + .drop(columns=["Remap_Entry"]) + ) + + assert len(idmap_remapped) == len(idmap_remapped["From"].unique()) + logger.info( + f"Total transcriptional regulators successfully mapped in UniProt: {len(idmap_remapped)}" + ) + + clean_idmap_path = ( + Path(root) + / "dpacman/data_files/processed/remap/idmapping_reviewed_true_processed_2025_08_11.tsv" + ) idmap_remapped.to_csv(clean_idmap_path, sep="\t") return clean_idmap_path -def debug_fimo_check(path_to_chrom_fimo, path_to_processed_chrom, chrom="Y", json_dir=""): + +def debug_fimo_check( + path_to_chrom_fimo, path_to_processed_chrom, chrom="Y", json_dir="" +): """ - Make sure we are properly extracting fimo sequences. + Make sure we are properly extracting fimo sequences. 
""" processed = pd.read_csv(path_to_processed_chrom) - processed["pred_motif_string"] = processed.apply(lambda row: extract_jaspar_motifs(row), axis=1) - processed["dna_sequence_rc"] = processed["dna_sequence"].apply(lambda x: get_reverse_complement(x)) - processed["pred_motif_string_rc"] = processed.apply(lambda row: extract_jaspar_motifs(row, reverse_complement=True), axis=1) + processed["pred_motif_string"] = processed.apply( + lambda row: extract_jaspar_motifs(row), axis=1 + ) + processed["dna_sequence_rc"] = processed["dna_sequence"].apply( + lambda x: get_reverse_complement(x) + ) + processed["pred_motif_string_rc"] = processed.apply( + lambda row: extract_jaspar_motifs(row, reverse_complement=True), axis=1 + ) processed_trs = processed["tr_name"].unique().tolist() fimo = pd.read_csv(path_to_chrom_fimo) - fimo["input_tr"] = fimo["sequence_name"].str.split("_",expand=True)[2] + fimo["input_tr"] = fimo["sequence_name"].str.split("_", expand=True)[2] fimo_valid = fimo.loc[ - (fimo["motif_alt_id"]==fimo["input_tr"]) & - (fimo["motif_alt_id"].isin(processed_trs)) - ].reset_index(drop=True) + (fimo["motif_alt_id"] == fimo["input_tr"]) + & (fimo["motif_alt_id"].isin(processed_trs)) + ].reset_index(drop=True) logger.info(f"Total valid FIMO matches: {len(fimo_valid)}") - logger.info(f"Total transcriptional regulators being considered: {len(processed_trs)}") - - # Load DNA + logger.info( + f"Total transcriptional regulators being considered: {len(processed_trs)}" + ) + + # Load DNA cache_debug = load_chrom_dna(chrom, {}, json_dir=json_dir) - + # Randomly select a positive and negative row to test - pos_row = fimo_valid.loc[fimo_valid["strand"]=="+"].sample(n=1, random_state=44) - neg_row = fimo_valid.loc[fimo_valid["strand"]=="-"].sample(n=1, random_state=44) + pos_row = fimo_valid.loc[fimo_valid["strand"] == "+"].sample(n=1, random_state=44) + neg_row = fimo_valid.loc[fimo_valid["strand"] == "-"].sample(n=1, random_state=44) - # Iterate through the rows + # Iterate through the rows for row in [pos_row, neg_row]: indices = [int(x) for x in row["sequence_name"].item().split("_")[-2::]] chipseq_start, chipseq_end = indices - strand = row['strand'].item() + strand = row["strand"].item() logger.info(f"ChIPseq start: {chipseq_start}, ChIPseq end: {chipseq_end}") logger.info(f"Strand: {strand}") - motif_start = chipseq_start + int(row["start"].item()) -1 + motif_start = chipseq_start + int(row["start"].item()) - 1 motif_end = chipseq_start + int(row["stop"].item()) - motif = row['matched_sequence'].item() + motif = row["matched_sequence"].item() full_seq = cache_debug[chipseq_start:chipseq_end].upper() logger.info(f"Full sequence: {full_seq}") - logger.info(f"Full sequence reverse complement: {get_reverse_complement(full_seq)}") + logger.info( + f"Full sequence reverse complement: {get_reverse_complement(full_seq)}" + ) our_motif = cache_debug[motif_start:motif_end].upper() - if strand=="+": + if strand == "+": logger.info(f"True motif found by FIMO: {motif}") logger.info(f"Extracted motif on our end: {our_motif}") logger.info(f"Correct extraction: {motif==our_motif}") - + matching_rows = processed.loc[ - (processed["dna_sequence"].str.contains(full_seq)) & - (processed["pred_motif_string"].str.contains(our_motif)) + (processed["dna_sequence"].str.contains(full_seq)) + & (processed["pred_motif_string"].str.contains(our_motif)) ] - if strand=="-": + if strand == "-": our_motif_rc = get_reverse_complement(our_motif) - + logger.info(f"True motif found by FIMO: {motif}") logger.info(f"Extracted motif on 
our end: {our_motif_rc}") logger.info(f"Correct extraction: {motif==our_motif_rc}") logger.info(f"Motif that will appear in the forward sequence: {our_motif}") matching_rows = processed.loc[ - (processed["dna_sequence"].str.contains(full_seq)) & - (processed["pred_motif_string"].str.contains(our_motif)) + (processed["dna_sequence"].str.contains(full_seq)) + & (processed["pred_motif_string"].str.contains(our_motif)) ] - + # Now find if there are rows with the same TR, and the same DNA sequence and motif matching_row_trs = sorted(matching_rows["tr_name"].unique().tolist()) - expected_tr = row['motif_alt_id'].item() + expected_tr = row["motif_alt_id"].item() logger.info(f"TR from selected row: {expected_tr}") logger.info(f"TRs with same motif: {','.join(matching_row_trs)}") logger.info(f"Expected TR in list: {expected_tr in matching_row_trs}") + def debug_remap_check(remap_path, path_to_processed_chrom, chrom="Y", json_dir=""): """ For debugging mode: pick a random row from processed remap. make sure the sequence matches the one we're getting here. @@ -709,27 +910,36 @@ def debug_remap_check(remap_path, path_to_processed_chrom, chrom="Y", json_dir=" remap["ChIPStart"] = remap["ChIPStart"].astype(int) remap["ChIPEnd"] = remap["ChIPEnd"].astype(int) - row = remap.loc[remap["#chrom"]=="Y"].sample(n=1, random_state=42) + row = remap.loc[remap["#chrom"] == "Y"].sample(n=1, random_state=42) start, end = row["ChIPStart"].item(), row["ChIPEnd"].item() - + cache_debug = load_chrom_dna(chrom, {}, json_dir=json_dir) test_seq = cache_debug[start:end].upper() - logger.info(f"Randomly sampled sequence ({len(test_seq)} nucleotides), chrY {start}:{end}\n\tsequence: {test_seq}") - should_find = remap.loc[ - (remap["ChIPStart"]==start) & - (remap["ChIPEnd"]==end) - ]["TR"].unique().tolist() - logger.info(f"Expect to find {len(should_find)} TRs: {', '.join(sorted(should_find))}") + logger.info( + f"Randomly sampled sequence ({len(test_seq)} nucleotides), chrY {start}:{end}\n\tsequence: {test_seq}" + ) + should_find = ( + remap.loc[(remap["ChIPStart"] == start) & (remap["ChIPEnd"] == end)]["TR"] + .unique() + .tolist() + ) + logger.info( + f"Expect to find {len(should_find)} TRs: {', '.join(sorted(should_find))}" + ) processed = pd.read_csv(path_to_processed_chrom) - did_find = processed.loc[ - processed["peak_sequence"]==test_seq - ]["tr_name"].unique().tolist() - logger.info(f"Looked up same sequence in processed chrY file.\nFound TRs: {', '.join(sorted(did_find))}") + did_find = ( + processed.loc[processed["peak_sequence"] == test_seq]["tr_name"] + .unique() + .tolist() + ) + logger.info( + f"Looked up same sequence in processed chrY file.\nFound TRs: {', '.join(sorted(did_find))}" + ) logger.info(f"found==expected: {did_find==should_find}") - + def main(cfg: DictConfig): debug = bool(cfg.data_task.debug) json_dir = cfg.data_task.json_dir @@ -740,13 +950,16 @@ def main(cfg: DictConfig): logger.info(f"Debug: {debug}") logger.info(f"Reading per-chrom final.csv under: {fimo_out_dir}") - + # process the idmap - idmap_path=Path(root) / cfg.data_task.idmap_path + idmap_path = Path(root) / cfg.data_task.idmap_path clean_idmap_path = clean_idmap(idmap_path) # If we don't have temp files to process - if not(os.path.exists(output_parts_folder)) or (os.path.exists(output_parts_folder) and len(os.listdir(output_parts_folder))<24): + if not (os.path.exists(output_parts_folder)) or ( + os.path.exists(output_parts_folder) + and len(os.listdir(output_parts_folder)) < 24 + ): paths_to_processed_dfs = build_dataset_fast_mp( 
fimo_out_dir=fimo_out_dir, json_dir=json_dir, @@ -754,31 +967,44 @@ def main(cfg: DictConfig): max_workers=max_workers, jaspar_boost=cfg.data_task.jaspar_boost, output_parts_folder=output_parts_folder, - keep_fimo_only=cfg.data_task.keep_fimo_only + keep_fimo_only=cfg.data_task.keep_fimo_only, ) else: - paths_to_processed_dfs = [output_parts_folder/x for x in os.listdir(output_parts_folder)] if output_parts_folder.exists() else [] - + paths_to_processed_dfs = ( + [output_parts_folder / x for x in os.listdir(output_parts_folder)] + if output_parts_folder.exists() + else [] + ) + # Debug methods: (1) make sure our peak sequences correspond to remap, (2) make sure our FIMO sequences correspond to FIMO results out_path = str(processed_output_csv).replace(".csv", ".parquet") if debug: - debug_remap_check(remap_path=Path(root) / cfg.data_task.remap_path, - path_to_processed_chrom=Path(output_parts_folder)/"chromY_processed.csv", - chrom="Y", json_dir=json_dir) - debug_fimo_check(path_to_chrom_fimo=Path(root) / cfg.data_task.fimo_out_dir / "chromY" / "fimo_annotations.csv", - path_to_processed_chrom=Path(output_parts_folder)/"chromY_processed.csv", - chrom="Y", json_dir=json_dir) + debug_remap_check( + remap_path=Path(root) / cfg.data_task.remap_path, + path_to_processed_chrom=Path(output_parts_folder) / "chromY_processed.csv", + chrom="Y", + json_dir=json_dir, + ) + debug_fimo_check( + path_to_chrom_fimo=Path(root) + / cfg.data_task.fimo_out_dir + / "chromY" + / "fimo_annotations.csv", + path_to_processed_chrom=Path(output_parts_folder) / "chromY_processed.csv", + chrom="Y", + json_dir=json_dir, + ) out_path = out_path.replace(".parquet", "_debug.parquet") - + logger.info(f"Combining {len(paths_to_processed_dfs)} processed parts with Polars") combine_processed_with_polars( paths_to_processed_dfs=paths_to_processed_dfs, idmap_path=clean_idmap_path, out_path=out_path, max_protein_len=cfg.data_task.max_protein_len, - seeds=cfg.data_task.seeds + seeds=cfg.data_task.seeds, ) - + # Delete the folder that had the temporary DFs, don't need these if False and output_parts_folder.exists(): for f in output_parts_folder.glob("*.csv"): @@ -786,6 +1012,7 @@ def main(cfg: DictConfig): output_parts_folder.rmdir() logger.info(f"Cleaned up temporary files in {output_parts_folder}") + if __name__ == "__main__": # On some clusters with older Python, 'fork' is default and fine. 
# If you hit issues (e.g., with threads/IO), uncomment spawn: diff --git a/dpacman/data_tasks/fimo/pre_fimo.py b/dpacman/data_tasks/fimo/pre_fimo.py index 6373a5f948c0682bc3c440ddef52012c2a59b136..65d1d13f43c7ebff58ee1aa5983deef83cef3fd4 100644 --- a/dpacman/data_tasks/fimo/pre_fimo.py +++ b/dpacman/data_tasks/fimo/pre_fimo.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 import pandas as pd import numpy as np -import rootutils -import logging import os import json import multiprocessing as mp @@ -10,9 +8,13 @@ from multiprocessing import Pool, cpu_count from omegaconf import DictConfig from pathlib import Path from hydra.core.hydra_config import HydraConfig +import rootutils +from dpacman.utils import pylogger +import logging root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -logger = logging.getLogger(__name__) +logger = pylogger.RankedLogger(__name__, rank_zero_only=True) + def init_worker(log_file, logger_name): """Initialize a logger in each worker.""" @@ -23,13 +25,16 @@ def init_worker(log_file, logger_name): # Avoid re-adding handlers if this logger is reused if not wlogger.handlers: handler = logging.FileHandler(log_file) - formatter = logging.Formatter('%(asctime)s - %(message)s') + formatter = logging.Formatter("%(asctime)s - %(message)s") handler.setFormatter(formatter) wlogger.addHandler(handler) - + return wlogger - -def process_chromosome(chrom, df, json_dir, output_path, example_dir, save_example_files, log_dir): + + +def process_chromosome( + chrom, df, json_dir, output_path, example_dir, save_example_files, log_dir +): log_file = Path(log_dir) / f"chrom_{chrom}.log" logger_name = f"logger_chrom_{chrom}" wlogger = init_worker(log_file, logger_name) @@ -52,13 +57,16 @@ def process_chromosome(chrom, df, json_dir, output_path, example_dir, save_examp f.write("".join(lines[:50])) wlogger.info(f"Saved example: {example_out}") -def assemble_main_input(input_csv: str, window_total: int, output_csv: str, save_example_files: bool): + +def assemble_main_input( + input_csv: str, window_total: int, output_csv: str, save_example_files: bool +): """ - Method for assembling the main input dataframe - + Method for assembling the main input dataframe + Args: - input_csv: path to input csv that is converted to the file for FIMO input - - window_total: int determining the total non-ChIPseq-peak nucleotides included in a datapoint + - window_total: int determining the total non-ChIPseq-peak nucleotides included in a datapoint - output_csv: where processed file will be saved - save_example_files: bool determining whether we save example files that can be easily viewed """ @@ -66,11 +74,11 @@ def assemble_main_input(input_csv: str, window_total: int, output_csv: str, save input_path = Path(root) / input_csv df = pd.read_csv(input_path, sep="\t") out = None # initialize out - + output_path = Path(root) / output_csv os.makedirs(output_path.parent, exist_ok=True) - if not(os.path.exists(output_path)): + if not (os.path.exists(output_path)): # 2) normalize chromosomes and exclude non-whole chromosomes df["chrom"] = df["chrom"].str.replace(r"^chr", "", regex=True) @@ -117,19 +125,17 @@ def assemble_main_input(input_csv: str, window_total: int, output_csv: str, save # 8) write csv out.to_csv(output_path, index=False) logger.info(f"Wrote {len(out)} rows to {output_path}") - + # Load the DF if we need if out is None: out = pd.read_csv(output_path) - + # 9) write example csv if necessary if save_example_files: example_dir = output_path.parent / "examples" 
os.makedirs(example_dir, exist_ok=True) output_csv_name = output_csv.split("/")[-1] - example_savepath = os.path.join( - example_dir, "example500_" + output_csv_name - ) + example_savepath = os.path.join(example_dir, "example500_" + output_csv_name) if not (os.path.exists(example_savepath)): out.sample(n=500, random_state=42).reset_index(drop=True).to_csv( @@ -138,9 +144,10 @@ def assemble_main_input(input_csv: str, window_total: int, output_csv: str, save logger.info( f"Saved example FIMO input file with 500 rows to: {example_savepath}" ) - + return out + def load_chrom_dna(chrom, cache, json_dir): """ Load DNA from the chromosome that we pre-downloaded @@ -155,15 +162,18 @@ def load_chrom_dna(chrom, cache, json_dir): cache[chrom] = json.load(f)["dna"] return cache[chrom] -def parallel_make_all_fasta_inputs(df, json_dir, output_path, example_dir, save_example_files=True, max_workers=8): + +def parallel_make_all_fasta_inputs( + df, json_dir, output_path, example_dir, save_example_files=True, max_workers=8 +): df["#chrom"] = df["#chrom"].astype(str) chromosomes = df["#chrom"].unique().tolist() - + log_dir = Path(HydraConfig.get().run.dir) / "logs" - + os.makedirs(log_dir, exist_ok=True) logger.info(f"Created {log_dir} for storing logs for subprocesses.") - + os.makedirs(example_dir, exist_ok=True) logger.info(f"Created {example_dir} for storing example inputs") @@ -175,16 +185,17 @@ def parallel_make_all_fasta_inputs(df, json_dir, output_path, example_dir, save_ with mp.Pool(processes=max_workers) as pool: pool.starmap(process_chromosome, args) + def extract_sequences(df, seq_fasta, json_dir, wlogger): """ - Make the main sequence fasta for this chromosome. Used for building the background model. + Make the main sequence fasta for this chromosome. Used for building the background model. 
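+    FASTA headers follow {row_index}_chr{chrom}_{TR}_{contextStart}_{contextEnd},
+    e.g. ">12_chr1_CTCF_10050_10250" (hedged example; values are illustrative).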
""" dna_cache = {} n_rows = len(df) checkpoints = set(int(n_rows * i / 100) for i in range(1, 101)) # 1% to 100% - + wlogger.info(f"Writing to {seq_fasta}") - if not(os.path.exists(seq_fasta)): + if not (os.path.exists(seq_fasta)): with open(seq_fasta, "w") as fa: for idx, row in df.iterrows(): chrom = str(row["#chrom"]) @@ -192,22 +203,27 @@ def extract_sequences(df, seq_fasta, json_dir, wlogger): dna = load_chrom_dna(chrom, dna_cache, json_dir) start = int(row["contextStart"]) end = int(row["contextEnd"]) - seq = dna[start:end] # end index is not included in ChIP-seq peaks + seq = dna[start:end] # end index is not included in ChIP-seq peaks header = f"{idx}_chr{chrom}_{tr}_{start}_{end}" fa.write(f">{header}\n{seq}\n") - - # log every 1% + + # log every 1% if idx in checkpoints: - wlogger.info(f" Reached {idx / n_rows:.0%} of the DataFrame (index {idx})") + wlogger.info( + f" Reached {idx / n_rows:.0%} of the DataFrame (index {idx})" + ) + def main(cfg: DictConfig): # 1) make the full input CSV paths = cfg.data_task.paths - df = assemble_main_input(input_csv=paths.input_csv, - window_total=cfg.data_task.window_total, - output_csv=paths.output_csv, - save_example_files=cfg.data_task.save_example_files) - + df = assemble_main_input( + input_csv=paths.input_csv, + window_total=cfg.data_task.window_total, + output_csv=paths.output_csv, + save_example_files=cfg.data_task.save_example_files, + ) + # Make example dir to use in future methods example_dir = Path(root) / paths.output_csv example_dir = example_dir.parent / "examples" @@ -219,14 +235,15 @@ def main(cfg: DictConfig): logger.info(f"Max workers available (cpu_count - 1): {max_workers}") max_workers = min(max_workers, total_chroms) logger.info(f"min(max_workers, total_chroms) = {max_workers}") - - parallel_make_all_fasta_inputs(df, - json_dir=paths.json_dir, - output_path=Path(root) / paths.chrom_output_path, - example_dir=example_dir, - save_example_files=cfg.data_task.save_example_files, - max_workers=max_workers) + parallel_make_all_fasta_inputs( + df, + json_dir=paths.json_dir, + output_path=Path(root) / paths.chrom_output_path, + example_dir=example_dir, + save_example_files=cfg.data_task.save_example_files, + max_workers=max_workers, + ) if __name__ == "__main__": diff --git a/dpacman/data_tasks/fimo/run_fimo.py b/dpacman/data_tasks/fimo/run_fimo.py index 09c0e9f5c53107374e48eebcfc3f2aea8117446c..4cbe0a8129f7ae512d2df8c63b21a2a2b571984c 100644 --- a/dpacman/data_tasks/fimo/run_fimo.py +++ b/dpacman/data_tasks/fimo/run_fimo.py @@ -12,9 +12,12 @@ from pathlib import Path import time import shutil from hydra.core.hydra_config import HydraConfig +import rootutils +from dpacman.utils import pylogger root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -logger = logging.getLogger(__name__) +logger = pylogger.RankedLogger(__name__, rank_zero_only=True) + def run_markov(fasta_get_markov, seq_fasta, bg_model): subprocess.check_call( @@ -23,10 +26,13 @@ def run_markov(fasta_get_markov, seq_fasta, bg_model): stderr=subprocess.DEVNULL, ) -def split_fasta(n_chunks, input_file, output_dir, debug=False, debug_n=1000, all_caps=True): + +def split_fasta( + n_chunks, input_file, output_dir, debug=False, debug_n=1000, all_caps=True +): """ Round-robin split SEQ_FASTA into chunked FASTA files. - If in debug mode, only keep the first 5 entries for each. + If in debug mode, only keep the first 5 entries for each. 
""" output_dir = Path(root) / output_dir out_names = [os.path.join(output_dir, f"to_scan_{i}.fa") for i in range(n_chunks)] @@ -34,11 +40,11 @@ def split_fasta(n_chunks, input_file, output_dir, debug=False, debug_n=1000, all chunk_counts = [0] * n_chunks # Count sequences per chunk logger.info(f"ALL CAPS mode: {all_caps}") - + with open(input_file) as inf: header = None seq_lines = [] - + for line in inf: if line.startswith(">"): if header is not None: @@ -46,7 +52,8 @@ def split_fasta(n_chunks, input_file, output_dir, debug=False, debug_n=1000, all if not debug or chunk_counts[idx] < debug_n: out_handles[idx].write(header) seqj = "".join(seq_lines) - if all_caps: seqj = seqj.upper() + if all_caps: + seqj = seqj.upper() out_handles[idx].write(seqj) chunk_counts[idx] += 1 header = line @@ -60,17 +67,18 @@ def split_fasta(n_chunks, input_file, output_dir, debug=False, debug_n=1000, all if not debug or chunk_counts[idx] < debug_n: out_handles[idx].write(header) seqj = "".join(seq_lines) - if all_caps: seqj = seqj.upper() + if all_caps: + seqj = seqj.upper() out_handles[idx].write(seqj) chunk_counts[idx] += 1 - + for o in out_handles: o.close() - + # Log chunk sizes for i, count in enumerate(chunk_counts): logger.info(f"Chunk {i}: {count} sequences") - + return out_names @@ -98,7 +106,7 @@ def run_fimo_chunk(cfg): wlogger = logging.getLogger(f"fimo_chunk_{chunk_id}") wlogger.setLevel(logging.DEBUG) wlogger.propagate = False # Don't double-log to root - + outdir = Path(cfg["outdir"]) os.makedirs(outdir, exist_ok=True) @@ -106,29 +114,29 @@ def run_fimo_chunk(cfg): fh = logging.FileHandler(log_file, mode="w", encoding="utf-8") fh.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) wlogger.addHandler(fh) - - # make an output directory for this chromosome + + # make an output directory for this chromosome wlogger.info(f"Chunk {cfg['chunk_id']} starting FIMO") wlogger.info(f"Threshold mode: {cfg['thresh_mode']}") - + try: call_list = [ - cfg["fimo_bin"], - "--oc", - outdir, - "--bfile", - cfg["bg_model"], - "--max-stored-scores", - str(cfg["max_stored"]), - "--thresh", - str(cfg["thresh"]), - "--qv-thresh", # threshold on q-value - "--no-pgc", # suppress parsing of genomic coordinates in FASTA sequence header - cfg["motif_file"], - cfg["fasta_path"], - ] - if cfg["thresh_mode"]!="q": - call_list = [x for x in call_list if x!="--qv-thresh"] + cfg["fimo_bin"], + "--oc", + outdir, + "--bfile", + cfg["bg_model"], + "--max-stored-scores", + str(cfg["max_stored"]), + "--thresh", + str(cfg["thresh"]), + "--qv-thresh", # threshold on q-value + "--no-pgc", # suppress parsing of genomic coordinates in FASTA sequence header + cfg["motif_file"], + cfg["fasta_path"], + ] + if cfg["thresh_mode"] != "q": + call_list = [x for x in call_list if x != "--qv-thresh"] assert "--qv-thresh" not in call_list with open(log_file, "a") as log_fh: subprocess.check_call( @@ -137,41 +145,63 @@ def run_fimo_chunk(cfg): stderr=log_fh, ) wlogger.info(f"\tChunk {cfg['chunk_id']} finished") - + # Delete the file - gotta save space! 
file_path = Path(cfg["fasta_path"]) if file_path.exists() and file_path.is_file(): file_path.unlink() wlogger.info(f"\tDeleted file: {file_path}") - + except subprocess.CalledProcessError as e: wlogger.error(f"\tChunk {chunk_id}: FIMO failed with error code {e.returncode}") raise return os.path.join(outdir, f"fimo.tsv") + def annotate_with_fimo(df, fdf): - df = df.reset_index().rename(columns={"index":"idx"}) - df["sequence_name"] = df["idx"].astype(str) + "_chr" + df["#chrom"] + "_" + df["TR"] + "_" + df["contextStart"].astype(str) + "_" + df["contextEnd"].astype(str) #construt it the same way as headers - - # Crucial: filter FDF results to only rows where the TF whose motif was found actually matches the TF that was detected there. - fdf["input_tr"] = fdf["sequence_name"].str.split("_",expand=True)[2] - true_matches = fdf.loc[ - fdf["motif_alt_id"]==fdf["input_tr"] - ].reset_index(drop=True) + df = df.reset_index().rename(columns={"index": "idx"}) + df["sequence_name"] = ( + df["idx"].astype(str) + + "_chr" + + df["#chrom"] + + "_" + + df["TR"] + + "_" + + df["contextStart"].astype(str) + + "_" + + df["contextEnd"].astype(str) + ) # construt it the same way as headers + + # Crucial: filter FDF results to only rows where the TF whose motif was found actually matches the TF that was detected there. + fdf["input_tr"] = fdf["sequence_name"].str.split("_", expand=True)[2] + true_matches = fdf.loc[fdf["motif_alt_id"] == fdf["input_tr"]].reset_index( + drop=True + ) logger.info(f"Length of full returned FIMO results: {len(fdf)}") - logger.info(f"Length of true matches, where the FIMO tr and the input tr match: {len(true_matches)}") - - true_matches = true_matches.merge(df[["sequence_name", "contextStart"]], on="sequence_name", how="left") - true_matches["genomic_start"] = true_matches["contextStart"] + true_matches["start"] - 1 + logger.info( + f"Length of true matches, where the FIMO tr and the input tr match: {len(true_matches)}" + ) + + true_matches = true_matches.merge( + df[["sequence_name", "contextStart"]], on="sequence_name", how="left" + ) + true_matches["genomic_start"] = ( + true_matches["contextStart"] + true_matches["start"] - 1 + ) true_matches["genomic_end"] = true_matches["contextStart"] + true_matches["stop"] true_matches["coord"] = ( - true_matches["genomic_start"].astype(str) + "-" + true_matches["genomic_end"].astype(str) + true_matches["genomic_start"].astype(str) + + "-" + + true_matches["genomic_end"].astype(str) ) - agg = true_matches.groupby("sequence_name")["coord"].agg(lambda hits: ",".join(hits)) + agg = true_matches.groupby("sequence_name")["coord"].agg( + lambda hits: ",".join(hits) + ) df["jaspar"] = df["sequence_name"].map(agg).fillna("") return df + def main(cfg: DictConfig): """ Main method for running FIMO analysis, searching JASPAR motifs against ChIP-seq peaks @@ -180,7 +210,7 @@ def main(cfg: DictConfig): paths = cfg.data_task.paths fimo = cfg.data_task.fimo meme = cfg.data_task.meme - + # set njobs to max or whatever # is specified by user njobs = fimo.njobs if njobs == "max": @@ -195,52 +225,72 @@ def main(cfg: DictConfig): if cfg.data_task.debug: chroms = chroms[0:1] logging.info(f" DEBUG MODE: running on only one chromosome: {chroms}") - + # 2) extract sequences & build BG model - for chrom in chroms: - path_to_fasta = Path(root) / Path(paths.input_fasta_outer_dir) / f"chr{chrom}" / paths.seq_fasta - path_to_bg = Path(root) / Path(paths.input_fasta_outer_dir) / f"chr{chrom}" / paths.bg_model + for chrom in chroms: + path_to_fasta = ( + Path(root) + / 
Path(paths.input_fasta_outer_dir) + / f"chr{chrom}" + / paths.seq_fasta + ) + path_to_bg = ( + Path(root) + / Path(paths.input_fasta_outer_dir) + / f"chr{chrom}" + / paths.bg_model + ) logging.info(f"Path to fasta file: {path_to_fasta}") logger.info(f"Building background model at {path_to_bg}…") - run_markov(Path(root)/meme.fasta_get_markov, path_to_fasta, Path(root) / path_to_bg) + run_markov( + Path(root) / meme.fasta_get_markov, path_to_fasta, Path(root) / path_to_bg + ) # 3) chunk FASTA and run FIMO in parallel - # make a folder to store the split fastas + # make a folder to store the split fastas chunk_folder = Path(path_to_fasta.parent) / "chunks" os.makedirs(chunk_folder, exist_ok=True) logger.info(f"Made directory {chunk_folder} to store {njobs} chunked fastas") - chunks = split_fasta(njobs, input_file=path_to_fasta, output_dir=chunk_folder, debug=cfg.data_task.debug, all_caps=cfg.data_task.all_caps) - + chunks = split_fasta( + njobs, + input_file=path_to_fasta, + output_dir=chunk_folder, + debug=cfg.data_task.debug, + all_caps=cfg.data_task.all_caps, + ) + chrom_outdir = Path(root) / paths.fimo_outdir / f"chrom{chrom}" os.makedirs(chrom_outdir, exist_ok=True) - + chunk_cfgs = [ dict( chunk_id=i, fasta_path=chunk, - fimo_outdir=Path(root)/ paths.fimo_outdir, + fimo_outdir=Path(root) / paths.fimo_outdir, fimo_bin=Path(root) / meme.fimo_bin, bg_model=path_to_bg, max_stored=fimo.max_stored, motif_file=Path(root) / meme.jaspar_motif_file, thresh=fimo.thresh, thresh_mode=fimo.thresh_mode, - outdir=Path(chrom_outdir) / f"chunk{i}" + outdir=Path(chrom_outdir) / f"chunk{i}", ) for i, chunk in enumerate(chunks) ] logger.info(f"Running FIMO in parallel ({njobs} jobs)…") start_time = time.time() - # Call the parallel jobs and get back a list of tsv paths + # Call the parallel jobs and get back a list of tsv paths with Pool(njobs) as pool: tsv_paths = pool.map(run_fimo_chunk, chunk_cfgs) end_time = time.time() - logger.info(f"COMPLETED FIMO ({njobs} parallel jobs) in {end_time-start_time:.2f}s") - # cleanup! delete the chunked input files + logger.info( + f"COMPLETED FIMO ({njobs} parallel jobs) in {end_time-start_time:.2f}s" + ) + # cleanup! delete the chunked input files if not any(chunk_folder.iterdir()): # Empty folder chunk_folder.rmdir() logger.info(f"Deleted empty folder: {chunk_folder}") - + # 4) merge chunked TSVs. 
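Before the merged hits are folded back into the peak table in step 5, the coordinate arithmetic performed by `annotate_with_fimo` (defined above) can be illustrated with made-up numbers. This is only a sketch; it assumes FIMO's `start`/`stop` are reported relative to the scanned window whose genomic offset (`contextStart`) is encoded in the sequence header.

```python
# Toy illustration (made-up values) of mapping a FIMO hit back to genomic coordinates.
context_start = 10_000            # genomic start of the scanned window, from the header
fimo_start, fimo_stop = 25, 35    # hit coordinates within the window, as FIMO reports them

genomic_start = context_start + fimo_start - 1   # 10_024, as computed in annotate_with_fimo
genomic_end = context_start + fimo_stop          # 10_035
coord = f"{genomic_start}-{genomic_end}"         # "10024-10035", later comma-joined per peak
```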
Some may be empty, so can't do a simple loop # delete intermediate folders as we go dfs = [] @@ -254,7 +304,7 @@ def main(cfg: DictConfig): except Exception as e: logger.error(f"Error reading {tsv}: {e}") raise # Or continue, depending on your needs - + # delete this folder to save storage chunk_dir = Path(tsv).parent try: @@ -264,16 +314,18 @@ def main(cfg: DictConfig): logger.warning(f"Could not delete chunk dir {chunk_dir}: {e}") combined = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame() - + # 5) annotate & write final CSV df = pd.read_csv(Path(root) / paths.input_csv, low_memory=False) df["#chrom"] = df["#chrom"].astype(str) - df = df.loc[df["#chrom"]==chrom].reset_index(drop=True) + df = df.loc[df["#chrom"] == chrom].reset_index(drop=True) output_full_csv_path = Path(root) / chrom_outdir / f"fimo_annotations.csv" combined.to_csv(output_full_csv_path, index=False) - logger.info(f"Merging FIMO results into input DataFrame, which has {len(df)} rows for chromosome {chrom}") + logger.info( + f"Merging FIMO results into input DataFrame, which has {len(df)} rows for chromosome {chrom}" + ) df = annotate_with_fimo(df, combined) - + final = df[ [ "#chrom", @@ -290,5 +342,6 @@ def main(cfg: DictConfig): final.to_csv(output_csv_path, index=False) logger.info(f"Wrote {len(final)} rows to {output_csv_path}") + if __name__ == "__main__": main() diff --git a/dpacman/data_tasks/split/remap.py b/dpacman/data_tasks/split/remap.py index 8f630ca323f842e14f633748f26cf705a38a073b..92f0a8951321fd6ddf6dd556e062fbeb451bb809 100644 --- a/dpacman/data_tasks/split/remap.py +++ b/dpacman/data_tasks/split/remap.py @@ -1,33 +1,44 @@ from collections import Counter, defaultdict from ortools.linear_solver import pywraplp import random -import logging from omegaconf import DictConfig -import rootutils import pandas as pd from pathlib import Path import os import numpy as np from sklearn.model_selection import train_test_split +from dpacman.data_tasks.fimo.post_fimo import get_reverse_complement +import json +import rootutils +from dpacman.utils import pylogger root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -logger = logging.getLogger(__name__) +logger = pylogger.RankedLogger(__name__, rank_zero_only=True) + def split_bipartite_fast( dna_clusters, - split_names=("train","val","test"), - ratios=(0.8,0.1,0.1), + split_names=("train", "val", "test"), + ratios=(0.8, 0.1, 0.1), ): # use sklearn test_size_1 = 0.2 test_size_2 = 0.5 - logger.info(f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})") + logger.info( + f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})" + ) X = dna_clusters - y = [0]*len(dna_clusters) - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size_1, random_state=0) - logger.info(f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})") - X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=test_size_2, random_state=0) - + y = [0] * len(dna_clusters) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=test_size_1, random_state=0 + ) + logger.info( + f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})" + ) + X_val, X_test, y_val, y_test = train_test_split( + X_test, y_test, test_size=test_size_2, random_state=0 + ) + dna_assign = {} for x in X_train: 
dna_assign[x] = "train" @@ -35,20 +46,21 @@ def split_bipartite_fast( dna_assign[x] = "val" for x in X_test: dna_assign[x] = "test" - - kept_by_split = {'train': len(X_train), 'val': len(X_val), 'test': len(X_test)} + + kept_by_split = {"train": len(X_train), "val": len(X_val), "test": len(X_test)} return dna_assign, kept_by_split + def split_bipartite_with_ratios_and_leaky( edges, - split_names=("train","val","test"), + split_names=("train", "val", "test"), ratios=(0.8, 0.1, 0.1), require_nonempty=False, - ratio_tolerance=None, # None = soft ratios only; 0.0 = exact band (use with care) + ratio_tolerance=None, # None = soft ratios only; 0.0 = exact band (use with care) bigM=None, shuffle_within_pair=False, seed=0, - test_edges_must=None, # NEW: list of (tf,dna) with duplicates OR dict {(tf,dna): count} + test_edges_must=None, # NEW: list of (tf,dna) with duplicates OR dict {(tf,dna): count} ): """ edges: list of (tf_cluster_id, dna_cluster_id). Duplicates allowed (-> weights). @@ -66,11 +78,11 @@ def split_bipartite_with_ratios_and_leaky( """ # Aggregate counts per pair w = Counter(edges) - tfs = {t for (t, _) in w} + tfs = {t for (t, _) in w} dnas = {d for (_, d) in w} - S = list(split_names) - rs = dict(zip(S, ratios)) - N = sum(w.values()) + S = list(split_names) + rs = dict(zip(S, ratios)) + N = sum(w.values()) if bigM is None: bigM = 1000 * max(1, N) @@ -90,7 +102,9 @@ def split_bipartite_with_ratios_and_leaky( if isinstance(test_edges_must, dict): for k, v in test_edges_must.items(): if not isinstance(k, tuple) or len(k) != 2: - raise ValueError("test_edges_must dict keys must be (tf_cluster, dna_cluster)") + raise ValueError( + "test_edges_must dict keys must be (tf_cluster, dna_cluster)" + ) if v < 0: raise ValueError("required_count must be non-negative") if v: @@ -113,39 +127,41 @@ def split_bipartite_with_ratios_and_leaky( raise RuntimeError("Could not create CBC solver.") # Binary cluster assignments - x = {(c,s): solver.BoolVar(f"x[{c},{s}]") for c in tfs for s in S} - y = {(d,s): solver.BoolVar(f"y[{d},{s}]") for d in dnas for s in S} + x = {(c, s): solver.BoolVar(f"x[{c},{s}]") for c in tfs for s in S} + y = {(d, s): solver.BoolVar(f"y[{d},{s}]") for d in dnas for s in S} # Each cluster in exactly one split for c in tfs: - solver.Add(sum(x[c,s] for s in S) == 1) + solver.Add(sum(x[c, s] for s in S) == 1) for d in dnas: - solver.Add(sum(y[d,s] for s in S) == 1) + solver.Add(sum(y[d, s] for s in S) == 1) # Integer kept counts per pair and split (allow partial within-pair) - k = {((c,d),s): solver.IntVar(0, w[(c,d)], f"k[{c},{d},{s}]") for (c,d) in w for s in S} + k = { + ((c, d), s): solver.IntVar(0, w[(c, d)], f"k[{c},{d},{s}]") + for (c, d) in w + for s in S + } # Only keep in split s if both endpoint clusters are assigned to s - for (c,d), wt in w.items(): + for (c, d), wt in w.items(): for s in S: - solver.Add(k[((c,d),s)] <= wt * x[c,s]) - solver.Add(k[((c,d),s)] <= wt * y[d,s]) + solver.Add(k[((c, d), s)] <= wt * x[c, s]) + solver.Add(k[((c, d), s)] <= wt * y[d, s]) # Enforce minimum kept counts in TEST for required pairs - for (c,d), req in req_test.items(): - solver.Add(k[((c,d), "test")] >= req) + for (c, d), req in req_test.items(): + solver.Add(k[((c, d), "test")] >= req) # Optional: ensure each split has at least one cluster (feasibility depends on counts) if require_nonempty: for s in S: - solver.Add( - sum(x[c,s] for c in tfs) + sum(y[d,s] for d in dnas) >= 1 - ) + solver.Add(sum(x[c, s] for c in tfs) + sum(y[d, s] for d in dnas) >= 1) # Kept counts per split 
and total K = {s: solver.IntVar(0, N, f"K[{s}]") for s in S} for s in S: - solver.Add(K[s] == sum(k[((c,d),s)] for (c,d) in w)) + solver.Add(K[s] == sum(k[((c, d), s)] for (c, d) in w)) T = solver.IntVar(0, N, "T") solver.Add(T == sum(K[s] for s in S)) @@ -153,7 +169,7 @@ def split_bipartite_with_ratios_and_leaky( dpos = {s: solver.NumVar(0, solver.infinity(), f"dpos[{s}]") for s in S} dneg = {s: solver.NumVar(0, solver.infinity(), f"dneg[{s}]") for s in S} for s in S: - solver.Add(K[s] - rs[s]*T == dpos[s] - dneg[s]) + solver.Add(K[s] - rs[s] * T == dpos[s] - dneg[s]) # Optional hard band around target ratios if ratio_tolerance is not None: @@ -172,11 +188,13 @@ def split_bipartite_with_ratios_and_leaky( status = solver.Solve() if status not in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE): - raise RuntimeError("No feasible solution (check ratio_tolerance vs. required test edges).") + raise RuntimeError( + "No feasible solution (check ratio_tolerance vs. required test edges)." + ) # Read cluster assignments - tf_assign = {c: next(s for s in S if x[c,s].solution_value() > 0.5) for c in tfs} - dna_assign = {d: next(s for s in S if y[d,s].solution_value() > 0.5) for d in dnas} + tf_assign = {c: next(s for s in S if x[c, s].solution_value() > 0.5) for c in tfs} + dna_assign = {d: next(s for s in S if y[d, s].solution_value() > 0.5) for d in dnas} # Kept counts per split kept_by_split = {s: int(round(K[s].solution_value())) for s in S} @@ -187,13 +205,13 @@ def split_bipartite_with_ratios_and_leaky( remaining_indices = {pair: list(pair_to_indices[pair]) for pair in pair_to_indices} # Allocate the kept examples per split (train/val/test) - for (c,d), wt in w.items(): + for (c, d), wt in w.items(): for s in S: - cnt = int(round(k[((c,d),s)].solution_value())) + cnt = int(round(k[((c, d), s)].solution_value())) if cnt > 0: - take = remaining_indices[(c,d)][:cnt] + take = remaining_indices[(c, d)][:cnt] split_to_indices[s].extend(take) - remaining_indices[(c,d)] = remaining_indices[(c,d)][cnt:] + remaining_indices[(c, d)] = remaining_indices[(c, d)][cnt:] # Everything left becomes leaky_test leaky_indices = [] @@ -202,49 +220,63 @@ def split_bipartite_with_ratios_and_leaky( leaky_indices.extend(idxs) split_to_indices["leaky_test"] = leaky_indices - split_to_edges = {s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices} + split_to_edges = { + s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices + } - return tf_assign, dna_assign, kept_by_split, total_kept, split_to_indices, split_to_edges + return ( + tf_assign, + dna_assign, + kept_by_split, + total_kept, + split_to_indices, + split_to_edges, + ) -from collections import Counter, defaultdict -import random class DSU: - def __init__(self): self.p = {} + def __init__(self): + self.p = {} + def find(self, x): - if x not in self.p: self.p[x] = x + if x not in self.p: + self.p[x] = x while self.p[x] != x: self.p[x] = self.p[self.p[x]] x = self.p[x] return x - def union(self, a,b): + + def union(self, a, b): ra, rb = self.find(a), self.find(b) - if ra != rb: self.p[rb] = ra + if ra != rb: + self.p[rb] = ra + def split_bipartite_by_components( edges, - split_names=("train","val","test"), - ratios=(0.8,0.1,0.1), + split_names=("train", "val", "test"), + ratios=(0.8, 0.1, 0.1), seed=0, require_nonempty=False, - test_edges_must=None, # None, list[(tf,dna)], or dict{(tf,dna): count} + test_edges_must=None, # None, list[(tf,dna)], or dict{(tf,dna): count} ): """ Guarantees exclusivity: each TF cluster and DNA cluster 
appears in at most one split. Strategy: find connected components in the TF–DNA bipartite graph and assign components wholesale. """ rng = random.Random(seed) - w = Counter(edges) # multiplicities per pair - if not w: raise ValueError("No edges.") + w = Counter(edges) # multiplicities per pair + if not w: + raise ValueError("No edges.") # 1) Build components with Union-Find (prefix to keep TF/DNA namespaces disjoint) dsu = DSU() - for (tf, dna) in w: + for tf, dna in w: dsu.union(("T", tf), ("D", dna)) comp_pairs = defaultdict(list) comp_weight = defaultdict(int) for (tf, dna), cnt in w.items(): - root = dsu.find(("T", tf)) # component id = root of TF endpoint + root = dsu.find(("T", tf)) # component id = root of TF endpoint comp_pairs[root].append((tf, dna)) comp_weight[root] += cnt @@ -256,18 +288,26 @@ def split_bipartite_by_components( target = {s: int(round(rs[s] * N)) for s in S} # 2) Pin components that contain required TEST pairs - pinned = {} # comp_root -> pinned_split ("test") + pinned = {} # comp_root -> pinned_split ("test") if test_edges_must: - req = Counter(test_edges_must) if not isinstance(test_edges_must, dict) else Counter(test_edges_must) + req = ( + Counter(test_edges_must) + if not isinstance(test_edges_must, dict) + else Counter(test_edges_must) + ) # Map each required pair to its component, ensure feasibility for (tf, dna), r in req.items(): if (tf, dna) not in w: raise ValueError(f"Required pair {(tf,dna)} not present.") if r > w[(tf, dna)]: - raise ValueError(f"Required count {r} for {(tf,dna)} exceeds available {w[(tf,dna)]}.") + raise ValueError( + f"Required count {r} for {(tf,dna)} exceeds available {w[(tf,dna)]}." + ) comp = dsu.find(("T", tf)) if comp in pinned and pinned[comp] != "test": - raise ValueError(f"Component conflict: already pinned to {pinned[comp]}, but {(tf,dna)} demands test.") + raise ValueError( + f"Component conflict: already pinned to {pinned[comp]}, but {(tf,dna)} demands test." + ) pinned[comp] = "test" # NOTE: pinning a pair pins the WHOLE component to test (to keep exclusivity). # If you only want some edges kept in test and discard the rest, handle below when materializing. @@ -287,7 +327,7 @@ def split_bipartite_by_components( # Ensure nonempty splits if requested (seed with largest remaining comps) if require_nonempty: - seeds = remaining[:min(len(S), len(remaining))] + seeds = remaining[: min(len(S), len(remaining))] for comp, s in zip(seeds, S): comp_assign[comp] = s kept_by_split[s] += comp_weight[comp] @@ -321,10 +361,12 @@ def split_bipartite_by_components( # (Left out for clarity; default is: keep the whole component in its split.) 
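A toy usage sketch (the cluster IDs are made up) may make the exclusivity guarantee concrete: because whole connected components are assigned wholesale, any TF and DNA clusters that share an edge always land in the same split.

```python
# Hypothetical example: TF1 links D1 and D2, so all three belong to one component.
toy_edges = [
    ("TF1", "D1"), ("TF1", "D2"), ("TF1", "D1"),   # duplicate pairs act as weights
    ("TF2", "D3"), ("TF3", "D4"), ("TF4", "D5"), ("TF5", "D6"),
]
tf_assign, dna_assign, kept, total, _, split_edges = split_bipartite_by_components(
    toy_edges, ratios=(0.8, 0.1, 0.1), seed=0
)
# Members of the same component can never be separated across splits.
assert tf_assign["TF1"] == dna_assign["D1"] == dna_assign["D2"]
```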
# 5) Build edge lists and simple cluster assignments - split_to_edges = {s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices} + split_to_edges = { + s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices + } tf_assign, dna_assign = {}, {} for comp, s in comp_assign.items(): - for (tf, dna) in comp_pairs[comp]: + for tf, dna in comp_pairs[comp]: tf_assign[tf] = s dna_assign[dna] = s @@ -339,38 +381,56 @@ def split_bipartite_by_components( dup_dna = {dn: ss for dn, ss in dna_in_split.items() if len(ss) > 1} assert not dup_tf and not dup_dna, f"Exclusivity violated: {dup_tf} {dup_dna}" - return tf_assign, dna_assign, kept_by_split, total_kept, split_to_indices, split_to_edges + return ( + tf_assign, + dna_assign, + kept_by_split, + total_kept, + split_to_indices, + split_to_edges, + ) + def print_split_ratios(kept_by_split): total = sum(kept_by_split.values()) train_pcnt = 100 * kept_by_split["train"] / total - val_pcnt = 100 * kept_by_split["val"] / total - test_pcnt = 100 * kept_by_split["test"] / total - logger.info(f"Cluster distribution - Train: {train_pcnt:.2f}%, Val: {val_pcnt:.2f}%, Test: {test_pcnt:.2f}%") + val_pcnt = 100 * kept_by_split["val"] / total + test_pcnt = 100 * kept_by_split["test"] / total + logger.info( + f"Cluster distribution - Train: {train_pcnt:.2f}%, Val: {val_pcnt:.2f}%, Test: {test_pcnt:.2f}%" + ) + -def make_edges(processed_fimo_path: str, protein_cluster_path: str, dna_cluster_path: str): +def make_edges( + processed_fimo_path: str, protein_cluster_path: str, dna_cluster_path: str +): """ Make edges for input to the splitting algorithm. Edges consist of: (tr_cluster_rep)_(dna_cluster_rep) where the cluster rep is the sequence ID """ # Read cluser data - protein_clusters = pd.read_csv(protein_cluster_path, header=None,sep="\t") - protein_clusters.columns=["tr_cluster_rep","tr_seqid"] - - dna_clusters = pd.read_csv(dna_cluster_path, header=None,sep="\t") - dna_clusters.columns=["dna_cluster_rep","dna_seqid"] + protein_clusters = pd.read_csv(protein_cluster_path, header=None, sep="\t") + protein_clusters.columns = ["tr_cluster_rep", "tr_seqid"] + + dna_clusters = pd.read_csv(dna_cluster_path, header=None, sep="\t") + dna_clusters.columns = ["dna_cluster_rep", "dna_seqid"] - # Read datapoints + # Read datapoints edges = pd.read_parquet(processed_fimo_path) - edges = pd.merge(edges, dna_clusters, on="dna_seqid",how="left") - edges = pd.merge(edges, protein_clusters, on="tr_seqid",how="left") - edges["edge"] = edges.apply(lambda row: (row["tr_cluster_rep"], row["dna_cluster_rep"]), axis=1) - + edges = pd.merge(edges, dna_clusters, on="dna_seqid", how="left") + edges = pd.merge(edges, protein_clusters, on="tr_seqid", how="left") + edges["edge"] = edges.apply( + lambda row: (row["tr_cluster_rep"], row["dna_cluster_rep"]), axis=1 + ) + logger.info(f"Total unique edges: {len(edges['edge'].unique().tolist())}") dup_edges = edges.loc[edges.duplicated("edge")]["edge"].unique().tolist() logger.info(f"Total edges with >1 datapoint: {len(dup_edges)}") - logger.info(f"Total datapoints belonging to a duplicate edge: {len(edges.loc[edges['edge'].isin(dup_edges)])}") + logger.info( + f"Total datapoints belonging to a duplicate edge: {len(edges.loc[edges['edge'].isin(dup_edges)])}" + ) return edges + def check_validity(train, val, test, split_by="both"): """ Rigorous check for no overlap @@ -379,137 +439,217 @@ def check_validity(train, val, test, split_by="both"): train_ids = set(train["ID"].unique().tolist()) val_ids = 
set(val["ID"].unique().tolist()) test_ids = set(test["ID"].unique().tolist()) - - assert len(train_ids.intersection(val_ids))==0 - assert len(train_ids.intersection(test_ids))==0 - assert len(val_ids.intersection(test_ids))==0 + + assert len(train_ids.intersection(val_ids)) == 0 + assert len(train_ids.intersection(test_ids)) == 0 + assert len(val_ids.intersection(test_ids)) == 0 logger.info(f"Pass! No overlap in IDs") - - if split_by!="dna": + + if split_by != "dna": train_tr_seqs = set(train["tr_sequence"].unique().tolist()) val_tr_seqs = set(val["tr_sequence"].unique().tolist()) test_tr_seqs = set(test["tr_sequence"].unique().tolist()) - - assert len(train_tr_seqs.intersection(val_tr_seqs))==0 - assert len(train_tr_seqs.intersection(test_tr_seqs))==0 - assert len(val_tr_seqs.intersection(test_tr_seqs))==0 + + assert len(train_tr_seqs.intersection(val_tr_seqs)) == 0 + assert len(train_tr_seqs.intersection(test_tr_seqs)) == 0 + assert len(val_tr_seqs.intersection(test_tr_seqs)) == 0 logger.info(f"Pass! No overlap in TR sequences") - + train_tr_reps = set(train["tr_cluster_rep"].unique().tolist()) val_tr_reps = set(val["tr_cluster_rep"].unique().tolist()) test_tr_reps = set(test["tr_cluster_rep"].unique().tolist()) - - assert len(train_tr_reps.intersection(val_tr_reps))==0 - assert len(train_tr_reps.intersection(test_tr_reps))==0 - assert len(val_tr_reps.intersection(test_tr_reps))==0 + + assert len(train_tr_reps.intersection(val_tr_reps)) == 0 + assert len(train_tr_reps.intersection(test_tr_reps)) == 0 + assert len(val_tr_reps.intersection(test_tr_reps)) == 0 logger.info(f"Pass! No overlap in TR cluster reps") - - if split_by!="protein": + + if split_by != "protein": train_dna_seqs = set(train["dna_sequence"].unique().tolist()) val_dna_seqs = set(val["dna_sequence"].unique().tolist()) test_dna_seqs = set(test["dna_sequence"].unique().tolist()) - - assert len(train_dna_seqs.intersection(val_dna_seqs))==0 - assert len(train_dna_seqs.intersection(test_dna_seqs))==0 - assert len(val_dna_seqs.intersection(test_dna_seqs))==0 + + assert len(train_dna_seqs.intersection(val_dna_seqs)) == 0 + assert len(train_dna_seqs.intersection(test_dna_seqs)) == 0 + assert len(val_dna_seqs.intersection(test_dna_seqs)) == 0 logger.info(f"Pass! No overlap in DNA sequences") - + train_dna_reps = set(train["dna_cluster_rep"].unique().tolist()) val_dna_reps = set(val["dna_cluster_rep"].unique().tolist()) test_dna_reps = set(test["dna_cluster_rep"].unique().tolist()) - - assert len(train_dna_reps.intersection(val_dna_reps))==0 - assert len(train_dna_reps.intersection(test_dna_reps))==0 - assert len(val_dna_reps.intersection(test_dna_reps))==0 + + assert len(train_dna_reps.intersection(val_dna_reps)) == 0 + assert len(train_dna_reps.intersection(test_dna_reps)) == 0 + assert len(val_dna_reps.intersection(test_dna_reps)) == 0 logger.info(f"Pass! No overlap in DNA cluster reps") + +def augment_rc(df): + """ + Get the reverse complement and add it as a datapoint, effectively doubling the dataset. 
+ Also flip the orientation of the scores + + columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"] + """ + df_rc = df.copy(deep=True) + + df_rc["dna_sequence"] = df_rc["dna_sequence"].apply( + lambda x: get_reverse_complement(x) + ) + df_rc["ID"] = df_rc["ID"] + "_rc" + df_rc["scores"] = df_rc["scores"].apply(lambda s: ",".join(s.split(",")[::-1])) + + final_df = pd.concat([df, df_rc]).reset_index(drop=True) + + return final_df + + def main(cfg: DictConfig): """ - Take a set of DNA clusters + protein clusters, and create the best possible splits into train/val/test. + Take a set of DNA clusters + protein clusters, and create the best possible splits into train/val/test. """ # construct edges from training data - edge_df = make_edges(processed_fimo_path=Path(root) / cfg.data_task.input_data_path, - protein_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.protein, - dna_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.dna) + edge_df = make_edges( + processed_fimo_path=Path(root) / cfg.data_task.input_data_path, + protein_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.protein, + dna_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.dna, + ) edges = edge_df["edge"].unique().tolist() - + # figure out if we actually even have a conflict total_proteins = len(edge_df["tr_seqid"].unique().tolist()) total_protein_clusters = len(edge_df["tr_cluster_rep"].unique().tolist()) - - no_protein_overlap = (total_proteins)==(total_protein_clusters) + + no_protein_overlap = (total_proteins) == (total_protein_clusters) logger.info(f"All proteins are in their own clusters: {no_protein_overlap}") - - if cfg.data_task.split_by=="dna": + + if cfg.data_task.split_by == "dna": logger.info(f"Easy split: all proteins are in their own clusters.") dna_clusters = edge_df["dna_cluster_rep"].unique().tolist() results = split_bipartite_fast( dna_clusters, - split_names=("train","val","test"), - ratios=(cfg.data_task.train_ratio, cfg.data_task.val_ratio, cfg.data_task.test_ratio), + split_names=("train", "val", "test"), + ratios=( + cfg.data_task.train_ratio, + cfg.data_task.val_ratio, + cfg.data_task.test_ratio, + ), ) dna_assign, kept_by_split = results - + # assign datapoints to cluster by their DNA cluster rep edge_df["split"] = edge_df["dna_cluster_rep"].map(dna_assign) else: results = split_bipartite_by_components( edges, - split_names=("train","val","test"), - ratios=(cfg.data_task.train_ratio, cfg.data_task.val_ratio, cfg.data_task.test_ratio), + split_names=("train", "val", "test"), + ratios=( + cfg.data_task.train_ratio, + cfg.data_task.val_ratio, + cfg.data_task.test_ratio, + ), require_nonempty=cfg.data_task.require_nonempty, seed=cfg.data_task.seed, test_edges_must=None, ) - - tf_assign, dna_assign, kept_by_split, total_kept, split_to_indices, split_to_edges = results - + + ( + tf_assign, + dna_assign, + kept_by_split, + total_kept, + split_to_indices, + split_to_edges, + ) = results + # Map each sample to its split print(tf_assign) print(dna_assign) edge_df["tr_split"] = edge_df["tr_cluster_rep"].map(tf_assign) edge_df["dna_split"] = edge_df["dna_cluster_rep"].map(dna_assign) - edge_df["same_split"] = edge_df["tr_split"]==edge_df["dna_split"] # should always be true if easy cluster + edge_df["same_split"] = ( + edge_df["tr_split"] == edge_df["dna_split"] + ) # should always be true if easy cluster edge_df["split"] = edge_df["tr_split"] print(edge_df) edge_df["split"] = np.where( edge_df["same_split"], - 
edge_df["split"], # keep existing split if same_split == True - "leak" # otherwise leak + edge_df["split"], # keep existing split if same_split == True + "leak", # otherwise leak ) print(edge_df) - + # Print ratios: hopefully close to desired (e.g. 80/10/10) print_split_ratios(kept_by_split) # Make train, val, test sets # make sure no ID is duplicate - assert len(edge_df["ID"].unique())==len(edge_df) - split_cols = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"] - train = edge_df.loc[ - edge_df["split"]=="train" - ].reset_index(drop=True)[split_cols] - val = edge_df.loc[ - edge_df["split"]=="val" - ].reset_index(drop=True)[split_cols] - test = edge_df.loc[ - edge_df["split"]=="test" - ].reset_index(drop=True)[split_cols] - + assert len(edge_df["ID"].unique()) == len(edge_df) + split_cols = [ + "ID", + "dna_sequence", + "tr_sequence", + "tr_cluster_rep", + "dna_cluster_rep", + "scores", + "split", + ] + train = edge_df.loc[edge_df["split"] == "train"].reset_index(drop=True)[split_cols] + val = edge_df.loc[edge_df["split"] == "val"].reset_index(drop=True)[split_cols] + test = edge_df.loc[edge_df["split"] == "test"].reset_index(drop=True)[split_cols] + # ensure there is no overlap check_validity(train, val, test, split_by=cfg.data_task.split_by) - - total = sum([len(train),len(val),len(test)]) + + total = sum([len(train), len(val), len(test)]) logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)") logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)") logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)") logger.info(f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}") - + + og_unique_dna = pd.concat([train, val, test]) + og_unique_dna = len(og_unique_dna["dna_sequence"].unique()) + + ## Now do RC data augmentation if asked + if cfg.data_task.augment_rc: + train = augment_rc(train) + val = augment_rc(val) + test = augment_rc(test) + + logger.info(f"Added reverse complement sequences to train, val, and test.") + + check_validity(train, val, test, split_by=cfg.data_task.split_by) + + total = sum([len(train), len(val), len(test)]) + logger.info( + f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)" + ) + logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)") + logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)") + logger.info( + f"Total sequences = {total}. Same as edges size? 
{total==len(edge_df)}" + ) + + # since we've added all these new DNA sequences, we do need a new apping of seq id to dna sequence + all_data = pd.concat([train, val, test]) + all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1] + dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"])) + assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"])) + new_map_path = str(Path(root) / cfg.data_task.dna_map_path).replace( + ".json", "_with_rc.json" + ) + + with open(new_map_path, "w") as f: + json.dump(dna_dict, f, indent=2) + logger.info( + f"Saved DNA maps with reverse complements (len {len(dna_dict)}=2*original map of len {og_unique_dna}=={len(dna_dict)==2*og_unique_dna}) to {new_map_path}" + ) + # create the output dir - split_out_dir = Path(root)/cfg.data_task.split_out_dir + split_out_dir = Path(root) / cfg.data_task.split_out_dir os.makedirs(split_out_dir, exist_ok=True) - split_final_cols = ["ID","dna_sequence","tr_sequence","scores","split"] - train[split_final_cols].to_csv(split_out_dir/"train.csv", index=False) - val[split_final_cols].to_csv(split_out_dir/"val.csv", index=False) - test[split_final_cols].to_csv(split_out_dir/"test.csv", index=False) - logger.info(f"Saved all splits to {split_out_dir}") \ No newline at end of file + split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "split"] + train[split_final_cols].to_csv(split_out_dir / "train.csv", index=False) + val[split_final_cols].to_csv(split_out_dir / "val.csv", index=False) + test[split_final_cols].to_csv(split_out_dir / "test.csv", index=False) + logger.info(f"Saved all splits to {split_out_dir}") diff --git a/dpacman/scripts/eval.py b/dpacman/scripts/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dpacman/scripts/preprocess.py b/dpacman/scripts/preprocess.py index a0a25dac2dcd6bed7aa78ee4202735bf718ad8fc..3a2815fef1a6abfc958a7bf8fdc533d7cd849919 100644 --- a/dpacman/scripts/preprocess.py +++ b/dpacman/scripts/preprocess.py @@ -16,6 +16,8 @@ from dpacman.data_tasks.fimo.post_fimo import main as post_fimo_main from dpacman.data_tasks.cluster.remap import main as cluster_remap_main from dpacman.data_tasks.split.remap import main as split_remap_main from dpacman.data_tasks.embeddings.dna import main as embed_dna_main +from dpacman.data_tasks.embeddings.protein import main as embed_protein_main + @hydra.main( config_path=str(root / "configs"), config_name="preprocess", version_base="1.3" @@ -59,7 +61,7 @@ def main(cfg: DictConfig): cluster_remap_main(cfg) else: raise ValueError(f"No clean pipeline defined for: {task_name}") - + # Split elif task_type == "split": if task_name == "remap": @@ -68,11 +70,14 @@ def main(cfg: DictConfig): raise ValueError(f"No clean pipeline defined for: {task_name}") # Embed - elif task_type=="embeddings": + elif task_type == "embeddings": if task_name == "dna": embed_dna_main(cfg) + elif task_name == "protein": + embed_protein_main(cfg) else: raise ValueError(f"No clean pipeline defined for: {task_name}") + # Unknown - error else: raise ValueError(f"Unknown task type: {task_type}") diff --git a/dpacman/scripts/run_embeddings.sh b/dpacman/scripts/run_embeddings.sh index 854e4f585715225012abad5af513e7d4a5a7f08f..a9912aa36396ef4aa2f2be78c0f69459fb88cbbc 100644 --- a/dpacman/scripts/run_embeddings.sh +++ b/dpacman/scripts/run_embeddings.sh @@ -10,7 +10,8 @@ mkdir -p "$run_dir" nohup python -u -m scripts.preprocess \ hydra.run.dir="${run_dir}" \ - 
data_task="${data_task_type}/dna" \ + data_task="${data_task_type}/protein" \ + data_task.debug="false" \ > "${run_dir}/run.log" 2>&1 & echo $! > "${run_dir}/pid.txt" diff --git a/dpacman/scripts/run_train.sh b/dpacman/scripts/run_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..d3d281b931f21cf0a53428dda9864fd802d20841 --- /dev/null +++ b/dpacman/scripts/run_train.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Manually specify values used in the config +main_task="train" +model_type="classifier" +timestamp=$(date "+%Y-%m-%d_%H-%M-%S") + +run_dir="$HOME/DPACMAN/logs/${main_task}/${model_type}/runs/${timestamp}" +mkdir -p "$run_dir" + +if [ -z "$WANDB_API_KEY" ]; then + read -s -p "Enter your WANDB API key: " wandb_key + echo + export WANDB_API_KEY="$wandb_key" +fi + +nohup python -u -m scripts.train \ + hydra.run.dir="${run_dir}" \ + > "${run_dir}/run.log" 2>&1 & + +echo $! > "${run_dir}/pid.txt" diff --git a/dpacman/scripts/train.py b/dpacman/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..199ec6261e6b31ee3c498d3314ceb37e7cb1e0c7 --- /dev/null +++ b/dpacman/scripts/train.py @@ -0,0 +1,131 @@ +from typing import Any, Dict, List, Optional, Tuple + +import hydra +import torch +import rootutils +import lightning as L +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from dpacman.utils import ( + RankedLogger, + extras, + get_metric_value, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) + +log = RankedLogger(__name__, rank_zero_only=True) + + +def h100_settings(): + # Use TensorFloat-32 for float32 matmuls → big speedup with tiny accuracy tradeoff + torch.set_float32_matmul_precision("high") # or "medium" for even more speed + + # (optional; older PyTorch toggle) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +@task_wrapper +def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """trains model given checkpoint on a datamodule train set. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Tuple[dict, dict] with metrics and dict with all instantiated objects. 
+ """ + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.data_module._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data_module) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + cfg.trainer, callbacks=callbacks, logger=logger + ) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + train_metrics = trainer.callback_metrics + + log.info("Training completed! Ready for testing.") + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main( + version_base="1.3", config_path=str(root / "configs"), config_name="train.yaml" +) +def main(cfg: DictConfig) -> None: + """Main entry point for evaluation. + + :param cfg: DictConfig configuration composed by Hydra. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
+ extras(cfg) + + h100_settings() # try using settings for faster h100s training + + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = get_metric_value( + metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") + ) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + main() diff --git a/dpacman/utils/__init__.py b/dpacman/utils/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..c172bef53e62c49c9b904973030e28ac670beb49 100644 --- a/dpacman/utils/__init__.py +++ b/dpacman/utils/__init__.py @@ -0,0 +1,11 @@ +from dpacman.utils.instantiators import instantiate_callbacks, instantiate_loggers +from dpacman.utils.logging_utils import log_hyperparameters +from dpacman.utils.pylogger import RankedLogger +from dpacman.utils.rich_utils import enforce_tags, print_config_tree +from dpacman.utils.utils import extras, get_metric_value, task_wrapper +from dpacman.utils.plotting_utils import ( + set_font, + load_esm2_type, + get_esm_embeddings, + get_one_hot_embeddings, +) diff --git a/dpacman/utils/clustering.py b/dpacman/utils/clustering.py index cdcfba10e99e795376c12bf54dfbb36d5470ece1..5a217355f47fd70edf71319a0ba459fb30eab5e0 100644 --- a/dpacman/utils/clustering.py +++ b/dpacman/utils/clustering.py @@ -4,61 +4,68 @@ import subprocess import sys from Bio import SeqIO import shutil - + import rootutils import logging logger = logging.getLogger(__name__) + def ensure_mmseqs_in_path(mmseqs_dir): """ Checks if MMseqs2 is in the PATH. If it's not, add it. MMseqs2 will not run if this is not done correctly. - + Args: mmseqs_dir (str): Directory containing MMseqs2 binaries """ - mmseqs_bin = os.path.join(mmseqs_dir, 'mmseqs') - + mmseqs_bin = os.path.join(mmseqs_dir, "mmseqs") + # Check if mmseqs is already in PATH - if shutil.which('mmseqs') is None: + if shutil.which("mmseqs") is None: # Export the MMseqs2 directory to PATH - os.environ['PATH'] = f"{mmseqs_dir}:{os.environ['PATH']}" + os.environ["PATH"] = f"{mmseqs_dir}:{os.environ['PATH']}" logger.info(f"\tAdded {mmseqs_dir} to PATH") + def process_fasta(fasta_path): - fasta_sequences = SeqIO.parse(open(fasta_path),'fasta') + fasta_sequences = SeqIO.parse(open(fasta_path), "fasta") d = {} for fasta in fasta_sequences: id, sequence = fasta.id, str(fasta.seq) d[id] = sequence - + return d + def analyze_clustering_result(input_fasta: str, tsv_path: str): """ Args: input_fasta (str): path to input fasta file """ - + # Process input fasta input_d = process_fasta(input_fasta) - + # Process clusters.tsv - clusters = pd.read_csv(f'{tsv_path}',sep='\t',header=None) - clusters = clusters.rename(columns={ - 0: 'representative seq_id', - 1: 'member seq_id' - }) - - clusters['representative seq'] = clusters['representative seq_id'].apply(lambda seq_id: input_d[seq_id]) - clusters['member seq'] = clusters['member seq_id'].apply(lambda seq_id: input_d[seq_id]) - + clusters = pd.read_csv(f"{tsv_path}", sep="\t", header=None) + clusters = clusters.rename(columns={0: "representative seq_id", 1: "member seq_id"}) + + clusters["representative seq"] = clusters["representative seq_id"].apply( + lambda seq_id: input_d[seq_id] + ) + clusters["member seq"] = clusters["member seq_id"].apply( + lambda seq_id: input_d[seq_id] + ) + # Sort them so that splitting results are reproducible - clusters = clusters.sort_values(by=['representative seq_id','member seq_id'],ascending=True).reset_index(drop=True) - + clusters = clusters.sort_values( + 
by=["representative seq_id", "member seq_id"], ascending=True + ).reset_index(drop=True) + return clusters + def make_fasta(sequences: dict, fasta_path: str): """ Makes a fasta file from sequences, where the key is the header and the value is the sequence. @@ -69,33 +76,45 @@ def make_fasta(sequences: dict, fasta_path: str): Returns: str: The path to the fasta file. """ - with open(fasta_path, 'w') as f: - for header, sequence in sequences.items(): - f.write(f'>{header}\n{sequence}\n') - + with open(fasta_path, "w") as f: + for header, sequence in sequences.items(): + f.write(f">{header}\n{sequence}\n") + return fasta_path -def run_mmseqs_clustering(input_fasta, output_dir, min_seq_id=0.3, c=0.8, cov_mode=0, cluster_mode=0, path_to_mmseqs='fuson_plm/mmseqs', dbtype=1): + +def run_mmseqs_clustering( + input_fasta, + output_dir, + min_seq_id=0.3, + c=0.8, + cov_mode=0, + cluster_mode=0, + path_to_mmseqs="fuson_plm/mmseqs", + dbtype=1, +): """ Runs MMSeqs2 clustering using easycluster module - + Args: - input_fasta (str): path to input fasta file, formatted >header\nsequence\n>header\nsequence.... + input_fasta (str): path to input fasta file, formatted >header\nsequence\n>header\nsequence.... output_dir (str): path to output dir for clustering results min_seq_id (float): number [0,1] representing --min-seq-id in cluster command c (float): nunber [0,1] representing -c in cluster command - cov_mode (int): number 0, 1, 2, or 3 representing --cov-mode in cluster command + cov_mode (int): number 0, 1, 2, or 3 representing --cov-mode in cluster command cluster_mode (int): number 0, 1, or 2 representing --cluster-mode in cluster command - + """ # Get mmseqs dir logger.info("\nRunning MMSeqs clustering...") - mmseqs_dir = os.path.join(path_to_mmseqs[0:path_to_mmseqs.index('/mmseqs')], 'mmseqs/bin') + mmseqs_dir = os.path.join( + path_to_mmseqs[0 : path_to_mmseqs.index("/mmseqs")], "mmseqs/bin" + ) logger.info(f"Running mmseqs clustering from {mmseqs_dir}") # Ensure MMseqs2 is in the PATH ensure_mmseqs_in_path(mmseqs_dir) - + # Define paths for MMseqs2 mmseqs_bin = "mmseqs" # Ensure this is in your PATH or provide the full path to mmseqs binary @@ -104,41 +123,81 @@ def run_mmseqs_clustering(input_fasta, output_dir, min_seq_id=0.3, c=0.8, cov_mo # Run MMseqs2 easy-cluster cmd_easy_cluster = [ - mmseqs_bin, "easy-cluster", input_fasta, os.path.join(output_dir, "mmseqs"), output_dir, - "--min-seq-id", str(min_seq_id), - "-c", str(c), - "--cov-mode", str(cov_mode), - "--cluster-mode", str(cluster_mode), - "--dbtype", str(dbtype) + mmseqs_bin, + "easy-cluster", + input_fasta, + os.path.join(output_dir, "mmseqs"), + output_dir, + "--min-seq-id", + str(min_seq_id), + "-c", + str(c), + "--cov-mode", + str(cov_mode), + "--cluster-mode", + str(cluster_mode), + "--dbtype", + str(dbtype), ] # Write the command to a log file logger.info("\n\tCommand entered to MMSeqs2:") logger.info("\t" + " ".join(cmd_easy_cluster) + "\n") - + subprocess.run(cmd_easy_cluster, check=True) logger.info(f"Clustering completed. Results are in {output_dir}") - + + def cluster_summary(clusters: pd.DataFrame): """ Summarizes how many clusters were formed, how big they are, etc ... 
""" - grouped_clusters = clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'}) - assert len(grouped_clusters) == len(clusters['representative seq_id'].unique()) # make sure number of cluster reps = # grouped clusters - - total_seqs = sum(grouped_clusters['member count']) + grouped_clusters = ( + clusters.groupby("representative seq_id")["member seq_id"] + .count() + .reset_index() + .rename(columns={"member seq_id": "member count"}) + ) + assert len(grouped_clusters) == len( + clusters["representative seq_id"].unique() + ) # make sure number of cluster reps = # grouped clusters + + total_seqs = sum(grouped_clusters["member count"]) logger.info(f"Created {len(grouped_clusters)} clusters of {total_seqs} sequences") - logger.info(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']==1])} clusters of size 1") - csize1_seqs = sum(grouped_clusters[grouped_clusters['member count']==1]['member count']) - logger.info(f"\t\tsequences: {csize1_seqs} ({round(100*csize1_seqs/total_seqs, 2)}%)") - - logger.info(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']>1])} clusters of size > 1") - csizeg1_seqs = sum(grouped_clusters[grouped_clusters['member count']>1]['member count']) - logger.info(f"\t\tsequences: {csizeg1_seqs} ({round(100*csizeg1_seqs/total_seqs, 2)}%)") + logger.info( + f"\t{len(grouped_clusters.loc[grouped_clusters['member count']==1])} clusters of size 1" + ) + csize1_seqs = sum( + grouped_clusters[grouped_clusters["member count"] == 1]["member count"] + ) + logger.info( + f"\t\tsequences: {csize1_seqs} ({round(100*csize1_seqs/total_seqs, 2)}%)" + ) + + logger.info( + f"\t{len(grouped_clusters.loc[grouped_clusters['member count']>1])} clusters of size > 1" + ) + csizeg1_seqs = sum( + grouped_clusters[grouped_clusters["member count"] > 1]["member count"] + ) + logger.info( + f"\t\tsequences: {csizeg1_seqs} ({round(100*csizeg1_seqs/total_seqs, 2)}%)" + ) logger.info(f"\tlargest cluster: {max(grouped_clusters['member count'])}") logger.info("\nCluster size breakdown below...") - - value_counts = grouped_clusters['member count'].value_counts().reset_index().rename(columns={'member count':'cluster size (n_members)','count': 'n_clusters'}) - logger.info(value_counts.sort_values(by='cluster size (n_members)',ascending=True).to_string(index=False)) \ No newline at end of file + + value_counts = ( + grouped_clusters["member count"] + .value_counts() + .reset_index() + .rename( + columns={"member count": "cluster size (n_members)", "count": "n_clusters"} + ) + ) + logger.info( + value_counts.sort_values( + by="cluster size (n_members)", ascending=True + ).to_string(index=False) + ) diff --git a/dpacman/utils/instantiators.py b/dpacman/utils/instantiators.py new file mode 100644 index 0000000000000000000000000000000000000000..aa88a5d060e57e973d3f8e6ba21807b0bb971527 --- /dev/null +++ b/dpacman/utils/instantiators.py @@ -0,0 +1,56 @@ +from typing import List + +import hydra +from lightning import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from dpacman.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config. + + :param callbacks_cfg: A DictConfig object containing callback configurations. + :return: A list of instantiated callbacks. 
+ """ + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config. + + :param logger_cfg: A DictConfig object containing logger configurations. + :return: A list of instantiated loggers. + """ + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/dpacman/utils/logging_utils.py b/dpacman/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a1f5b832fde5695c0e3968e42e252774f4dfed7 --- /dev/null +++ b/dpacman/utils/logging_utils.py @@ -0,0 +1,57 @@ +from typing import Any, Dict + +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import OmegaConf + +from dpacman.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def log_hyperparameters(object_dict: Dict[str, Any]) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - Number of model parameters + + :param object_dict: A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. + """ + hparams = {} + + cfg = OmegaConf.to_container(object_dict["cfg"]) + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! 
Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data_module"] = cfg["data_module"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/dpacman/utils/models.py b/dpacman/utils/models.py index 717a7be0e4bf23851c454aff8c58dbab3fef8dc1..5a293378f4f1e0bc4720787f723d319d2d226915 100644 --- a/dpacman/utils/models.py +++ b/dpacman/utils/models.py @@ -1,11 +1,13 @@ """ Model-related utilities, such as setting seed """ + import torch import numpy as np import random import os + def set_seed(seed: int = 42) -> None: np.random.seed(seed) random.seed(seed) diff --git a/dpacman/utils/plotting_utils.py b/dpacman/utils/plotting_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..93ab0bed8d445492681f3660e7fd86dd6ba53fd0 --- /dev/null +++ b/dpacman/utils/plotting_utils.py @@ -0,0 +1,260 @@ +# Plotting utils, such as setting ubuntu font +import matplotlib.pyplot as plt +import matplotlib.font_manager as fm +from matplotlib.font_manager import FontProperties +import pickle +import torch +from transformers import EsmModel, AutoTokenizer +import logging +import os +import numpy as np + + +def set_font(): + # Load and set the font + # Get the directory where this script lives + utils_dir = os.path.dirname(os.path.abspath(__file__)) + font_dir = os.path.join(utils_dir, "ubuntu_font") # adjust as needed + + # Paths for regular, bold, italic fonts + regular_font_path = os.path.join(font_dir, "Ubuntu-Regular.ttf") + bold_font_path = os.path.join(font_dir, "Ubuntu-Bold.ttf") + italic_font_path = os.path.join(font_dir, "Ubuntu-Italic.ttf") + bold_italic_font_path = os.path.join(font_dir, "Ubuntu-BoldItalic.ttf") + + # Load the font properties + regular_font = FontProperties(fname=regular_font_path) + bold_font = FontProperties(fname=bold_font_path) + italic_font = FontProperties(fname=italic_font_path) + bold_italic_font = FontProperties(fname=bold_italic_font_path) + + # Add the fonts to the font manager + fm.fontManager.addfont(regular_font_path) + fm.fontManager.addfont(bold_font_path) + fm.fontManager.addfont(italic_font_path) + fm.fontManager.addfont(bold_italic_font_path) + + # Set the font family globally to Ubuntu + plt.rcParams["font.family"] = regular_font.get_name() + + # Set the font family globally to Ubuntu + plt.rcParams["font.family"] = regular_font.get_name() + plt.rcParams["mathtext.fontset"] = "custom" + plt.rcParams["mathtext.rm"] = regular_font.get_name() + plt.rcParams["mathtext.it"] = italic_font.get_name() + plt.rcParams["mathtext.bf"] = bold_font.get_name() + + +def redump_pickle_dictionary(pickle_path): + """ + Loads a pickle dictionary and redumps it in its location. 
This allows a clean reset for a pickle built with 'ab+' + """ + entries = {} + # Load one by one + with open(pickle_path, "rb") as f: + while True: + try: + entry = pickle.load(f) + entries.update(entry) + except EOFError: + break # End of file reached + except Exception as e: + print(f"An error occurred: {e}") + break + # Redump + with open(pickle_path, "wb") as f: + pickle.dump(entries, f) + + +def load_esm2_type(esm_type, device=None): + """ + Loads ESM-2 version of a specified version (e.g. esm2_t33_650M_UR50D) + """ + # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings + logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + model = EsmModel.from_pretrained(f"facebook/{esm_type}") + tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}") + + model.to(device) + model.eval() # disables dropout for deterministic results + + return model, tokenizer, device + + +def get_esm_embeddings( + model, + tokenizer, + sequences, + device, + average=True, + print_updates=False, + savepath=None, + save_at_end=False, + max_length=None, +): + """ + Compute ESM embeddings. + + Args: + model + tokenizer + sequences + device + average: if True, the average embeddings will be taken and returned + savepath: if savepath is not None, the embeddings will be saved somewhere. It must be a pickle + """ + # Correct save path to pickle if necessary + if savepath is not None: + if savepath[-4::] != ".pkl": + savepath += ".pkl" + + # If no max length was passed, just set it to the maximum in the dataset + max_seq_len = max([len(s) for s in sequences]) + if max_length is None: + max_length = max_seq_len + 2 # +2 for BOS, EOS + + # Initialize an empty dict to store the ESM embeddings + embedding_dict = {} + # Iterate through the seqs + for i in range(len(sequences)): + sequence = sequences[i] + # Get the embeddings + with torch.no_grad(): + inputs = tokenizer( + sequence, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_length, + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + + outputs = model(**inputs) + embedding = outputs.last_hidden_state + + # remove extra dimension + embedding = embedding.squeeze(0) + # remove BOS and EOS tokens + embedding = embedding[1:-1, :] + + # Convert embeddings to numpy array (if needed) + embedding = embedding.cpu().numpy() + + # Average (if necessary) + if average: + embedding = embedding.mean(0) + + # Add to dictionary + embedding_dict[sequence] = embedding + + # Save individual embedding (if necessary) + if not (savepath is None) and not (save_at_end): + with open(savepath, "ab+") as f: + d = {sequence: embedding} + pickle.dump(d, f) + + # Print update (if necessary) + if print_updates: + print(f"sequence {i+1}: {sequence[0:10]}...") + + # Dump all at once at the end (if necessary) + if not (savepath is None): + # If saving for the first time, just dump it + if save_at_end: + with open(savepath, "wb") as f: + pickle.dump(embedding_dict, f) + # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely + else: + redump_pickle_dictionary(savepath) + + # Return the dictionary + return embedding_dict + + +def one_hot_encode_sequence(seq, add=True, max_len=None): + """ + One-hot encode a single protein sequence. Pads or truncates to max_len. 
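A brief usage sketch of the two ESM helpers defined above (the checkpoint name is one of the standard `facebook/esm2` releases; the sequences are made-up examples):

```python
# Load ESM-2 and embed protein sequences as mean-pooled vectors.
model, tokenizer, device = load_esm2_type("esm2_t33_650M_UR50D")
embeddings = get_esm_embeddings(
    model,
    tokenizer,
    sequences=["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "MENDEL"],
    device=device,
    average=True,   # one 1280-dim vector per sequence for this checkpoint
)
```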
+ + Parameters: + - seq: protein sequence (string of amino acids) + - max_len: desired fixed length (pads with zeros or truncates) + + Returns: + - 2D numpy array of shape (max_len, 20) + """ + AA_ORDER = list("ACDEFGHIKLMNPQRSTVWY") + AA_TO_IDX = {aa: i for i, aa in enumerate(AA_ORDER)} + + if max_len is None: + max_len = len(seq) + + one_hot = np.zeros((max_len, len(AA_ORDER)), dtype=np.float32) + for i, aa in enumerate(seq[:max_len]): + if aa in AA_TO_IDX: + one_hot[i, AA_TO_IDX[aa]] = 1.0 + + # Add (if necessary) + if add: + # add across length dimension + one_hot = np.sum(one_hot, axis=0) + + return one_hot + + +def get_one_hot_embeddings( + sequences, + print_updates=False, + savepath=None, + add=True, + save_at_end=False, + max_length=None, +): + """ + Compute One Hot embeddings + """ + # Correct save path to pickle if necessary + if savepath is not None: + if savepath[-4::] != ".pkl": + savepath += ".pkl" + + # If no max length was passed, just set it to the maximum in the dataset + max_seq_len = max([len(s) for s in sequences]) + if max_length is None: + max_length = max_seq_len + 2 # +2 for BOS, EOS + + # Initialize an empty dict to store the ESM embeddings + embedding_dict = {} + # Iterate through the seqs + for i in range(len(sequences)): + sequence = sequences[i] + embedding = one_hot_encode_sequence(sequence, add=add, max_len=None) + # Add to dictionary + embedding_dict[sequence] = embedding + + # Save individual embedding (if necessary) + if not (savepath is None) and not (save_at_end): + with open(savepath, "ab+") as f: + d = {sequence: embedding} + pickle.dump(d, f) + + # Print update (if necessary) + if print_updates: + print(f"sequence {i+1}: {sequence[0:10]}...") + + # Dump all at once at the end (if necessary) + if not (savepath is None): + # If saving for the first time, just dump it + if save_at_end: + with open(savepath, "wb") as f: + pickle.dump(embedding_dict, f) + # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely + else: + redump_pickle_dictionary(savepath) + + # Return the dictionary + return embedding_dict diff --git a/dpacman/utils/pylogger.py b/dpacman/utils/pylogger.py new file mode 100644 index 0000000000000000000000000000000000000000..31a76c376b8da6212f37b148bdd7182f5b0ce553 --- /dev/null +++ b/dpacman/utils/pylogger.py @@ -0,0 +1,55 @@ +import logging +from typing import Mapping, Optional + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = False, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 
+ """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log( + self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs + ) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. + """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError( + "The `rank_zero_only.rank` needs to be set before use" + ) + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/dpacman/utils/rich_utils.py b/dpacman/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b53fe1d381e3f2cecf9b16c80b3f20cf5d5592ad --- /dev/null +++ b/dpacman/utils/rich_utils.py @@ -0,0 +1,103 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from dpacman.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data_module", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints the contents of a DictConfig as a tree structure using the Rich library. + + :param cfg: A DictConfig composed by Hydra. + :param print_order: Determines in what order config components are printed. Default is ``("data", "model", + "callbacks", "logger", "trainer", "paths", "extras")``. + :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. + :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + ( + queue.append(field) + if field in cfg + else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." 
+ ) + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config. + + :param cfg: A DictConfig composed by Hydra. + :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. + """ + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/dpacman/utils/utils.py b/dpacman/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2f053900b2d50c6a43872476ddf0cc08cf6885 --- /dev/null +++ b/dpacman/utils/utils.py @@ -0,0 +1,121 @@ +import warnings +from importlib.util import find_spec +from typing import Any, Callable, Dict, Optional, Tuple + +from omegaconf import DictConfig + +from dpacman.utils import pylogger, rich_utils + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + + :param cfg: A DictConfig object containing the config tree. + """ + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. 
(adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ... + return metric_dict, object_dict + ``` + + :param task_func: The task function to be wrapped. + + :return: The wrapped task function. + """ + + def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value( + metric_dict: Dict[str, Any], metric_name: Optional[str] +) -> Optional[float]: + """Safely retrieves value of the metric logged in LightningModule. + + :param metric_dict: A dict containing metric values. + :param metric_name: If provided, the name of the metric to retrieve. + :return: If a metric name was provided, the value of the metric. + """ + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") + + return metric_value diff --git a/h100_env.yaml b/h100_env.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ca2d5689945317789e8bf1b45dce87efe6e5f750 --- /dev/null +++ b/h100_env.yaml @@ -0,0 +1,63 @@ +# reasons you might want to use `environment.yaml` instead of `requirements.txt`: +# - pip installs packages in a loop, without ensuring dependencies across all packages +# are fulfilled simultaneously, but conda achieves proper dependency control across +# all packages +# - conda allows for installing packages without requiring certain compilers or +# libraries to be available in the system, since it installs precompiled binaries +# in case of errors look here: https://pytorch.org/get-started/previous-versions/ + +name: dnabind2 + +channels: + - conda-forge + - defaults + - nvidia # GH200 + - pytorch + +# it is strongly recommended to specify versions of packages installed through conda +# to avoid situation when version-unspecified packages install their latest major +# versions which can sometimes break things + +# current approach below keeps the dependencies in the same major versions across all +# users, but allows for different minor and patch versions of packages where backwards +# compatibility is usually guaranteed + +dependencies: + - python=3.10 + - dask[complete] + - pip>=23 + - lightning=2.5.1 + - cudnn=9.10.2.21 + - torchmetrics=0.11.4 + - pip: + - torch==2.6.0+cu124 + - rootutils==1.0.7 + - hydra-core==1.3.2 # Hydra for config management + - hydra-colorlog==1.2.0 # Allow colorful logging in Hydra + - omegaconf==2.3.0 # Required by hydra-core + - pandas==2.2.3 + - lxml==5.3.0 + - pymex==0.9.31 + - gitpython==3.1.44 + - black==25.1.0 # code formatter + - tqdm==4.67.1 + - matplotlib==3.10.3 + - transformers==4.55.2 + - biopython==1.85 + - ortools==9.14.6206 + - fair-esm==2.0.0 + - scikit-learn==1.7.1 + - rich==14.1.0 + - wandb==0.21.1 + - --extra-index-url https://download.pytorch.org/whl/cu124 + - -e . + +# conda install -c nvidia -c conda-forge cuda-toolkit=12.4 ninja cmake -y +# use the toolkit inside the conda env +#export CUDA_HOME="$CONDA_PREFIX" +#export PATH="$CUDA_HOME/bin:$PATH" +#export LD_LIBRARY_PATH="$CUDA_HOME/lib64:$LD_LIBRARY_PATH" +# recommended by many CUDA builds +#export CUDACXX="$CUDA_HOME/bin/nvcc" +#which nvcc && nvcc -V # should now show 12.4 under $CONDA_PREFIX/bin/nvcc +#pip install --no-build-isolation mamba_ssm \ No newline at end of file