svincoff committed on
Commit
29899b4
·
1 Parent(s): bc0d37c

training works

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitignore +6 -1
  2. configs/callbacks/default.yaml +21 -0
  3. configs/callbacks/early_stopping.yaml +15 -0
  4. configs/callbacks/model_checkpoint.yaml +17 -0
  5. configs/callbacks/model_summary.yaml +5 -0
  6. dpacman/classifier/model/__init__.py → configs/callbacks/none.yaml +0 -0
  7. configs/callbacks/rich_progress_bar.yaml +4 -0
  8. configs/data_module/pair.yaml +13 -0
  9. configs/data_modules/pair.yaml +0 -9
  10. configs/data_task/cluster/remap.yaml +1 -1
  11. configs/data_task/download/genome.yaml +1 -1
  12. configs/data_task/download/remap.yaml +1 -1
  13. configs/data_task/embeddings/dna.yaml +8 -4
  14. configs/data_task/embeddings/protein.yaml +14 -0
  15. configs/data_task/fimo/post_fimo.yaml +1 -1
  16. configs/data_task/fimo/pre_fimo.yaml +1 -1
  17. configs/data_task/fimo/run_fimo.yaml +1 -1
  18. configs/data_task/split/remap.yaml +3 -0
  19. configs/extras/default.yaml +8 -0
  20. configs/logger/aim.yaml +28 -0
  21. configs/logger/comet.yaml +12 -0
  22. configs/logger/csv.yaml +7 -0
  23. configs/logger/many_loggers.yaml +9 -0
  24. configs/logger/mlflow.yaml +12 -0
  25. configs/logger/neptune.yaml +9 -0
  26. configs/logger/tensorboard.yaml +10 -0
  27. configs/logger/wandb.yaml +16 -0
  28. configs/model/classifier.yaml +9 -0
  29. configs/{models → model}/pooling/truncatedsvd.yaml +0 -0
  30. configs/models/classifier.yaml +0 -11
  31. configs/preprocess.yaml +1 -1
  32. configs/train.yaml +37 -2
  33. configs/trainer/cpu.yaml +5 -0
  34. configs/trainer/ddp.yaml +9 -0
  35. configs/trainer/ddp_sim.yaml +7 -0
  36. configs/trainer/default.yaml +19 -0
  37. configs/trainer/gpu.yaml +5 -0
  38. configs/trainer/mps.yaml +5 -0
  39. dpacman/classifier/loss.py +58 -0
  40. dpacman/classifier/model.py +258 -0
  41. dpacman/classifier/model/clustering_data.py +0 -383
  42. dpacman/classifier/model/compress_embeddings.py +0 -54
  43. dpacman/classifier/model/compute_embeddings.py +0 -560
  44. dpacman/classifier/model/extract_tf_symbols.py +0 -27
  45. dpacman/classifier/model/loss.py +0 -34
  46. dpacman/classifier/model/make_pair_list.py +0 -220
  47. dpacman/classifier/model/make_peak_fasta.py +0 -13
  48. dpacman/classifier/model_tmp/clustering_data.py +139 -47
  49. dpacman/classifier/model_tmp/compress_embeddings.py +15 -7
  50. dpacman/classifier/model_tmp/compute_embeddings.py +125 -59
.gitignore CHANGED
@@ -29,4 +29,9 @@ dpacman/nohup.out
  dpacman/*/__pycache__/
  dpacman/data_tasks/split/__pycache__/
  dpacman/data_tasks/cluster/__pycache__/
- dpacman/data_tasks/embeddings/__pycache__/
+ dpacman/data_tasks/embeddings/__pycache__/
+ dpacman/combine_shards.py
+ dpacman/combine.log
+ dpacman/loss_sim.py
+ dpacman/loss_temp.py
+ dpacman/peak_examples/
configs/callbacks/default.yaml ADDED
@@ -0,0 +1,21 @@
+ defaults:
+   - model_checkpoint
+   - early_stopping
+   - model_summary
+   - _self_
+
+ model_checkpoint:
+   dirpath: ${paths.output_dir}/checkpoints
+   filename: "epoch_{epoch:03d}"
+   monitor: "val/loss"
+   mode: "min"
+   save_last: True
+   auto_insert_metric_name: False
+
+ early_stopping:
+   monitor: "val/loss"
+   patience: 100
+   mode: "min"
+
+ model_summary:
+   max_depth: -1
configs/callbacks/early_stopping.yaml ADDED
@@ -0,0 +1,15 @@
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
+
+ early_stopping:
+   _target_: lightning.pytorch.callbacks.EarlyStopping
+   monitor: ??? # quantity to be monitored, must be specified !!!
+   min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
+   patience: 3 # number of checks with no improvement after which training will be stopped
+   verbose: False # verbosity mode
+   mode: "min" # "max" means higher metric value is better, can be also "min"
+   strict: True # whether to crash the training if monitor is not found in the validation metrics
+   check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
+   stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
+   divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
+   check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
+   # log_rank_zero_only: False # this keyword argument isn't available in stable version
configs/callbacks/model_checkpoint.yaml ADDED
@@ -0,0 +1,17 @@
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
+
+ model_checkpoint:
+   _target_: lightning.pytorch.callbacks.ModelCheckpoint
+   dirpath: null # directory to save the model file
+   filename: null # checkpoint filename
+   monitor: null # name of the logged metric which determines when model is improving
+   verbose: False # verbosity mode
+   save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+   save_top_k: 1 # save k best models (determined by above metric)
+   mode: "min" # "max" means higher metric value is better, can be also "min"
+   auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
+   save_weights_only: False # if True, then only the model's weights will be saved
+   every_n_train_steps: null # number of training steps between checkpoints
+   train_time_interval: null # checkpoints are monitored at the specified time interval
+   every_n_epochs: null # number of epochs between checkpoints
+   save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
configs/callbacks/model_summary.yaml ADDED
@@ -0,0 +1,5 @@
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
+
+ model_summary:
+   _target_: lightning.pytorch.callbacks.RichModelSummary
+   max_depth: 1 # the maximum depth of layer nesting that the summary will include
dpacman/classifier/model/__init__.py → configs/callbacks/none.yaml RENAMED
File without changes
configs/callbacks/rich_progress_bar.yaml ADDED
@@ -0,0 +1,4 @@
+ # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
+
+ rich_progress_bar:
+   _target_: lightning.pytorch.callbacks.RichProgressBar
configs/data_module/pair.yaml ADDED
@@ -0,0 +1,13 @@
+ _target_: dpacman.data_modules.pair.PairDataModule
+
+ train_file: data_files/processed/splits/by_dna/babytrain.csv
+ val_file: data_files/processed/splits/by_dna/babyval.csv
+ test_file: data_files/processed/splits/by_dna/babytest.csv
+
+ tr_shelf_path: data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf
+ dna_shelf_path: data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf
+
+ batch_size: 32
+ num_workers: 8
+
+ maximize_num_workers: False
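Note: data_module/pair.yaml is meant to be resolved by Hydra into the PairDataModule named in _target_. A minimal smoke-test sketch follows; the constructor is assumed to mirror the YAML keys, which this diff does not confirm.

# Hypothetical smoke test for the new pair data module (constructor signature
# assumed to mirror configs/data_module/pair.yaml; not confirmed by this commit).
from dpacman.data_modules.pair import PairDataModule

dm = PairDataModule(
    train_file="data_files/processed/splits/by_dna/babytrain.csv",
    val_file="data_files/processed/splits/by_dna/babyval.csv",
    test_file="data_files/processed/splits/by_dna/babytest.csv",
    tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf",
    dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf",
    batch_size=32,
    num_workers=8,
    maximize_num_workers=False,
)
dm.setup(stage="fit")
batch = next(iter(dm.train_dataloader()))  # expected keys: binder_emb, glm_emb, labels, ...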
configs/data_modules/pair.yaml DELETED
@@ -1,9 +0,0 @@
-
- train_file: data_files/splits/train.csv
- val_file: data_files/splits/val.csv
- test_file: data_files/splits/test.csv
-
- batch_size: 32
- num_workers: 8
-
- maximize_num_workers: False
configs/data_task/cluster/remap.yaml CHANGED
@@ -1,5 +1,5 @@
  name: remap
- type: cluster
+ task_type: cluster

  max_protein_length: 1998
configs/data_task/download/genome.yaml CHANGED
@@ -1,5 +1,5 @@
  name: genome
- type: download
+ task_type: download
  output_dir: dpacman/data_files/raw/genomes
  genomes:
    - hg38
configs/data_task/download/remap.yaml CHANGED
@@ -1,5 +1,5 @@
  name: remap
- type: download
+ task_type: download

  nr_url: https://remap.univ-amu.fr/storage/remap2022/hg38/MACS2/remap2022_nr_macs2_hg38_v1_0.bed.gz
  nr_output_dir: dpacman/data_files/raw/remap
configs/data_task/embeddings/dna.yaml CHANGED
@@ -1,9 +1,13 @@
  name: dna
- type: embeddings
+ task_type: embeddings

  genome_json_dir: null
- chrom_model: caduceus
- input_file: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence.json
+ chrom_model: segmentnt
+ input_file: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence_with_rc.json
  out_dir: dpacman/data_files/processed/embeddings/fimo_hits_only

- device: gpu
+ device: gpu
+
+ batch_size: 1
+
+ debug: false
configs/data_task/embeddings/protein.yaml CHANGED
@@ -0,0 +1,14 @@
+ name: protein
+ task_type: embeddings
+
+ prot_model: esm
+ input_file: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/tr_seqid_to_tr_sequence.json
+ out_dir: dpacman/data_files/processed/embeddings/fimo_hits_only
+
+ device: gpu
+
+ save_as_shelf: true
+
+ batch_size: 1
+
+ debug: false
configs/data_task/fimo/post_fimo.yaml CHANGED
@@ -1,5 +1,5 @@
  name: post_fimo
- type: fimo
+ task_type: fimo

  fimo_out_dir: dpacman/data_files/processed/fimo/fimo_out_q
  processed_output_csv: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed.csv
configs/data_task/fimo/pre_fimo.yaml CHANGED
@@ -1,5 +1,5 @@
  name: pre_fimo
- type: fimo
+ task_type: fimo

  paths:
    input_csv: dpacman/data_files/processed/remap/remap2022_crm_macs2_hg38_v1_0_clean.tsv
configs/data_task/fimo/run_fimo.yaml CHANGED
@@ -1,5 +1,5 @@
  name: run_fimo
- type: fimo
+ task_type: fimo

  debug: true
configs/data_task/split/remap.yaml CHANGED
@@ -10,7 +10,10 @@ cluster_output_paths:
  input_data_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed_seed0.parquet
  split_out_dir: dpacman/data_files/processed/splits

+ dna_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence.json
+
  split_by: both # protein, dna, or both
+ augment_rc: true

  test_ratio: 0.10
  val_ratio: 0.10
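Note: the new dna_map_path plus augment_rc: true point to reverse-complement augmentation of the peak sequences at split time. An illustrative, standalone sketch of that transformation follows; the split task's actual implementation is not shown in this view.

# Illustrative only: reverse-complement augmentation over a dna_seqid -> sequence map.
_COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan")

def reverse_complement(seq: str) -> str:
    # Reverse the sequence and swap A<->T, C<->G; N stays N.
    return seq.translate(_COMPLEMENT)[::-1]

def augment_with_rc(dna_map: dict) -> dict:
    # Add an "<id>_rc" entry for every peak so both strands appear in training.
    out = dict(dna_map)
    for seq_id, seq in dna_map.items():
        out[f"{seq_id}_rc"] = reverse_complement(seq)
    return out

assert reverse_complement("ACCGT") == "ACGGT"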
configs/extras/default.yaml ADDED
@@ -0,0 +1,8 @@
+ # disable python warnings if they annoy you
+ ignore_warnings: False
+
+ # ask user for tags if none are provided in the config
+ enforce_tags: True
+
+ # pretty print config tree at the start of the run using Rich library
+ print_config: True
configs/logger/aim.yaml ADDED
@@ -0,0 +1,28 @@
+ # https://aimstack.io/
+
+ # example usage in lightning module:
+ # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
+
+ # open the Aim UI with the following command (run in the folder containing the `.aim` folder):
+ # `aim up`
+
+ aim:
+   _target_: aim.pytorch_lightning.AimLogger
+   repo: ${paths.root_dir} # .aim folder will be created here
+   # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
+
+   # aim allows to group runs under experiment name
+   experiment: null # any string, set to "default" if not specified
+
+   train_metric_prefix: "train/"
+   val_metric_prefix: "val/"
+   test_metric_prefix: "test/"
+
+   # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
+   system_tracking_interval: 10 # set to null to disable system metrics tracking
+
+   # enable/disable logging of system params such as installed packages, git info, env vars, etc.
+   log_system_params: true
+
+   # enable/disable tracking console logs (default value is true)
+   capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
configs/logger/comet.yaml ADDED
@@ -0,0 +1,12 @@
+ # https://www.comet.ml
+
+ comet:
+   _target_: lightning.pytorch.loggers.comet.CometLogger
+   api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
+   save_dir: "${paths.output_dir}"
+   project_name: "lightning-hydra-template"
+   rest_api_key: null
+   # experiment_name: ""
+   experiment_key: null # set to resume experiment
+   offline: False
+   prefix: ""
configs/logger/csv.yaml ADDED
@@ -0,0 +1,7 @@
+ # csv logger built in lightning
+
+ csv:
+   _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
+   save_dir: "${paths.output_dir}"
+   name: "csv/"
+   prefix: ""
configs/logger/many_loggers.yaml ADDED
@@ -0,0 +1,9 @@
+ # train with many loggers at once
+
+ defaults:
+   # - comet
+   - csv
+   # - mlflow
+   # - neptune
+   - tensorboard
+   - wandb
configs/logger/mlflow.yaml ADDED
@@ -0,0 +1,12 @@
+ # https://mlflow.org
+
+ mlflow:
+   _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
+   # experiment_name: ""
+   # run_name: ""
+   tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
+   tags: null
+   # save_dir: "./mlruns"
+   prefix: ""
+   artifact_location: null
+   # run_id: ""
configs/logger/neptune.yaml ADDED
@@ -0,0 +1,9 @@
+ # https://neptune.ai
+
+ neptune:
+   _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
+   api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
+   project: username/lightning-hydra-template
+   # name: ""
+   log_model_checkpoints: True
+   prefix: ""
configs/logger/tensorboard.yaml ADDED
@@ -0,0 +1,10 @@
+ # https://www.tensorflow.org/tensorboard/
+
+ tensorboard:
+   _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+   save_dir: "${paths.output_dir}/tensorboard/"
+   name: null
+   log_graph: False
+   default_hp_metric: True
+   prefix: ""
+   # version: ""
configs/logger/wandb.yaml ADDED
@@ -0,0 +1,16 @@
+ # https://wandb.ai
+
+ wandb:
+   _target_: lightning.pytorch.loggers.wandb.WandbLogger
+   # name: "" # name of the run (normally generated by wandb)
+   save_dir: "${paths.output_dir}"
+   offline: False
+   id: null # pass correct id to resume experiment!
+   anonymous: null # enable anonymous logging
+   project: "dnabind"
+   log_model: False # upload lightning ckpts
+   prefix: "" # a string to put at the beginning of metric keys
+   # entity: "" # set to name of your wandb team
+   group: ""
+   tags: []
+   job_type: ""
configs/model/classifier.yaml ADDED
@@ -0,0 +1,9 @@
+ _target_: dpacman.classifier.model.BindPredictor
+
+ lr: 1e-4
+ alpha: 20
+ gamma: 20
+ weight_decay: 0.01
+
+ glm_input_dim: 1029
+ compressed_dim: 1029
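Note: Hydra resolves _target_ to the new BindPredictor and passes the remaining keys as constructor arguments. A short sketch of that resolution using the standard instantiate API, with values copied from the YAML above (any BindPredictor argument not listed keeps its Python default):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "_target_": "dpacman.classifier.model.BindPredictor",
    "lr": 1e-4,
    "alpha": 20,
    "gamma": 20,
    "weight_decay": 0.01,
    "glm_input_dim": 1029,
    "compressed_dim": 1029,
})
model = instantiate(cfg)  # builds BindPredictor(lr=1e-4, alpha=20, gamma=20, ...)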
configs/{models → model}/pooling/truncatedsvd.yaml RENAMED
File without changes
configs/models/classifier.yaml DELETED
@@ -1,11 +0,0 @@
- name: classifier
- type: train
-
- params:
-   epochs: 10
-   batch_size: 32
-   lr: 1e-4
-   seed: 42
-
- out_dir: null
- pair_list: null
configs/preprocess.yaml CHANGED
@@ -6,4 +6,4 @@ defaults:
    - hydra: default # ← tells Hydra to use the logging/output config
    - data_task: download/genome

- task_name: preprocess/${data_task.type}
+ task_name: preprocess/${data_task.task_type}
configs/train.yaml CHANGED
@@ -2,7 +2,42 @@ defaults:
    - _self_
    - paths: default
    - hydra: default # ← tells Hydra to use the logging/output config
+   - data_module: pair
+   - model: classifier
    - trainer: gpu
-   - data_task: model/classifier
+   - extras: default
+   - logger: wandb
+   - callbacks: default

+   # experiment configs allow for version control of specific hyperparameters
+   # e.g. best hyperparameters for given model and datamodule
+   - experiment: null
+
+   # config for hyperparameter optimization
+   - hparams_search: null
+
+   # debugging config (enable through command line, e.g. `python train.py debug=default`)
+   - debug: null
+
- task_name: train/${data_task.type}
+ task_name: train/${model}
+
+ # tags to help you identify your experiments
+ # you can overwrite this in experiment configs
+ # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
+ tags: ["dev"]
+
+ # set False to skip model training
+ train: True
+
+ # evaluate on test set, using best model weights achieved during training
+ # lightning chooses best weights based on the metric specified in checkpoint callback
+ test: True
+
+ # simply provide checkpoint path to resume training
+ ckpt_path: null
+
+ # seed for random number generators in pytorch, numpy and python.random
+ seed: 42
+
+ trainer:
+   max_epochs: 20
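Note: train.yaml now composes data_module, model, trainer, logger, callbacks and extras in one defaults list. A hypothetical train.py entry point matching that layout is sketched below; the repository's actual training script is not part of this 50-file view, so the wiring here is an assumption.

import hydra
import lightning as L
from omegaconf import DictConfig

@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig) -> None:
    # Assumed wiring: instantiate each composed config group and hand it to the Trainer.
    L.seed_everything(cfg.seed)
    datamodule = hydra.utils.instantiate(cfg.data_module)
    model = hydra.utils.instantiate(cfg.model)
    callbacks = [hydra.utils.instantiate(c) for c in cfg.callbacks.values()]
    loggers = [hydra.utils.instantiate(lg) for lg in cfg.logger.values()]
    trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=loggers)
    if cfg.train:
        trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
    if cfg.test:
        trainer.test(model, datamodule=datamodule, ckpt_path="best")

if __name__ == "__main__":
    main()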
configs/trainer/cpu.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - default
+
+ accelerator: cpu
+ devices: 1
configs/trainer/ddp.yaml ADDED
@@ -0,0 +1,9 @@
+ defaults:
+   - default
+
+ strategy: ddp
+
+ accelerator: gpu
+ devices: 4
+ num_nodes: 1
+ sync_batchnorm: True
configs/trainer/ddp_sim.yaml ADDED
@@ -0,0 +1,7 @@
+ defaults:
+   - default
+
+ # simulate DDP on CPU, useful for debugging
+ accelerator: cpu
+ devices: 2
+ strategy: ddp_spawn
configs/trainer/default.yaml ADDED
@@ -0,0 +1,19 @@
+ _target_: lightning.pytorch.trainer.Trainer
+
+ default_root_dir: ${paths.output_dir}
+
+ min_epochs: 1 # prevents early stopping
+ max_epochs: 10
+
+ accelerator: cpu
+ devices: 1
+
+ # mixed precision for extra speed-up
+ # precision: 16
+
+ # perform a validation loop every N training epochs
+ check_val_every_n_epoch: 1
+
+ # set True to ensure deterministic results
+ # makes training slower but gives more reproducibility than just setting seeds
+ deterministic: False
configs/trainer/gpu.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - default
+
+ accelerator: gpu
+ devices: 1
configs/trainer/mps.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - default
+
+ accelerator: mps
+ devices: 1
dpacman/classifier/loss.py ADDED
@@ -0,0 +1,58 @@
+ """
+ Define loss functions needed for training the model
+ """
+
+ import torch
+ from torch.nn import functional as F
+
+
+ def bce_loss_masked(logits, targets, nonpeak_mask, pos_weight=None):
+     """
+     Compute the masked Binary Cross Entropy, only on certain positions.
+     We will only compute BCE on positions where nonpeak_mask == 1.0; the mask represents non-peak positions
+     """
+     loss = F.binary_cross_entropy_with_logits(
+         logits, targets, reduction="none", pos_weight=pos_weight
+     )
+     denom = nonpeak_mask.sum().clamp_min(1.0)
+     return (loss * nonpeak_mask).sum() / denom
+
+
+ def mse_peaks_only(logits, targets, peak_mask, eps=1e-8):
+     """
+     Calculate MSE on peaks only.
+     """
+     probs = torch.sigmoid(logits)
+     mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction="sum") / (
+         peak_mask.sum() + eps
+     )
+     return mse_peaks
+
+
+ def calculate_loss(logits, targets, eps=1e-8, alpha=1.0, gamma=1.0):
+     """
+     Combine masked-BCE + global-MSE to get a loss value
+     """
+     # Calculate peak and non-peak masks.
+     # Anything outside a peak will have a label equal to 0.
+     nonpeak_mask = (targets == 0).float()
+     peak_mask = (targets > 0).float()
+
+     bce_nonpeak = bce_loss_masked(logits, targets, nonpeak_mask)
+     mse_peak = mse_peaks_only(logits, targets, peak_mask, eps=eps)
+
+     loss = alpha * bce_nonpeak + gamma * mse_peak
+
+     return loss
+
+
+ def accuracy_percentage(logits, targets, peak_thresh=0.5):
+     """
+     Compute accuracy in predicting high-confidence peaks (probability > 0.5)
+     """
+     probs = torch.sigmoid(logits)
+     preds_bin = (probs >= 0.5).float()
+     labels = (targets >= peak_thresh).float()
+     correct = (preds_bin == labels).float().sum()
+     total = torch.numel(labels)
+     return (correct / max(1, total)).item() * 100.0
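Note: the combined loss applies BCE only where targets == 0 (non-peak positions) and MSE only where targets > 0 (peak positions), then mixes the two terms with alpha and gamma. A quick usage sketch on dummy per-nucleotide tensors (shapes follow the (B, Lg) convention used by the model; the toy labels are illustrative only):

import torch
from dpacman.classifier.loss import calculate_loss, accuracy_percentage

B, Lg = 4, 128
logits = torch.randn(B, Lg)          # raw per-nucleotide model outputs (no sigmoid)
targets = torch.rand(B, Lg)          # peak labels in [0, 1]
targets[targets < 0.7] = 0.0         # zero out "non-peak" positions for this toy example

loss = calculate_loss(logits, targets, alpha=20, gamma=20)  # alpha/gamma as in classifier.yaml
acc = accuracy_percentage(logits, targets)
print(f"loss={loss.item():.4f}  peak accuracy={acc:.1f}%")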
dpacman/classifier/model.py ADDED
@@ -0,0 +1,258 @@
+ """
+ Lightning Module for the binding model.
+ """
+
+ import torch
+ from torch import nn
+ from lightning import LightningModule
+ from dpacman.utils.models import set_seed
+ from .loss import calculate_loss
+
+ set_seed()
+
+
+ class LocalCNN(nn.Module):
+     def __init__(self, dim: int = 256, kernel_size: int = 3):
+         super().__init__()
+         padding = kernel_size // 2
+         self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
+         self.act = nn.GELU()
+         self.ln = nn.LayerNorm(dim)
+
+     def forward(self, x: torch.Tensor):
+         # x: (batch, L, dim)
+         out = self.conv(x.transpose(1, 2))  # → (batch, dim, L)
+         out = self.act(out)
+         out = out.transpose(1, 2)  # → (batch, L, dim)
+         return self.ln(out + x)  # residual
+
+
+ class CrossModalBlock(nn.Module):
+     def __init__(self, dim: int = 256, heads: int = 8):
+         super().__init__()
+         # self-attention for both sides
+         self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
+         self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
+         self.ln_b1 = nn.LayerNorm(dim)
+         self.ln_g1 = nn.LayerNorm(dim)
+
+         self.ffn_b = nn.Sequential(
+             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+         )
+         self.ffn_g = nn.Sequential(
+             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+         )
+         self.ln_b2 = nn.LayerNorm(dim)
+         self.ln_g2 = nn.LayerNorm(dim)
+
+         # cross attention (glm queries, binder keys/values)
+         # so the DNA path is updated by the transcription factors
+         self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
+         self.ln_c1 = nn.LayerNorm(dim)
+         self.ffn_c = nn.Sequential(
+             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+         )
+         self.ln_c2 = nn.LayerNorm(dim)
+
+     def forward(self, binder: torch.Tensor, glm: torch.Tensor):
+         """
+         binder: (batch, Lb, dim)
+         glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
+         returns: updated glm representation (batch, Lg, dim)
+         """
+         # binder: self-attn + ffn
+         b = binder
+         b_sa, _ = self.sa_binder(b, b, b)
+         b = self.ln_b1(b + b_sa)
+         b_ff = self.ffn_b(b)
+         b = self.ln_b2(b + b_ff)
+
+         # glm: self-attn + ffn
+         g = glm
+         g_sa, _ = self.sa_glm(g, g, g)
+         g = self.ln_g1(g + g_sa)
+         g_ff = self.ffn_g(g)
+         g = self.ln_g2(g + g_ff)
+
+         # cross-attention: glm queries binder and glm embeddings are updated
+         g_to_b_ca, _ = self.cross_attn(g, b, b)
+         g = self.ln_c1(g + g_to_b_ca)
+         g_ff = self.ffn_c(g)
+         g = self.ln_c2(g + g_ff)
+         return g  # (batch, Lg, dim)
+
+
+ class DimCompressor(nn.Module):
+     """
+     Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
+     If in_dim == out_dim, behaves as identity.
+     """
+
+     def __init__(self, in_dim: int, out_dim: int = 256):
+         super().__init__()
+         if in_dim == out_dim:
+             self.net = nn.Identity()
+         else:
+             hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
+             self.net = nn.Sequential(
+                 nn.LayerNorm(in_dim),
+                 nn.Linear(in_dim, hidden),
+                 nn.GELU(),
+                 nn.Linear(hidden, out_dim),
+             )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: (B, L, in_dim)
+         return self.net(x)
+
+
+ class BindPredictor(LightningModule):
+     def __init__(
+         self,
+         # input_dim: int = 256,  # OLD: single input dim
+         binder_input_dim: int = 1280,  # NEW: TF (binder) original dim (e.g., 1280)
+         glm_input_dim: int = 256,  # NEW: DNA/GLM original dim (e.g., 256)
+         compressed_dim: int = 256,  # NEW: learnable compressed dim
+         hidden_dim: int = 256,
+         heads: int = 8,
+         num_layers: int = 4,
+         lr: float = 1e-4,
+         alpha: float = 20,
+         gamma: float = 20,
+         use_local_cnn_on_glm: bool = True,
+         weight_decay: float = 0.01,
+     ):
+         # Init
+         super(BindPredictor, self).__init__()
+         self.save_hyperparameters()
+
+         # Learnable compressor for binder -> 256, then project to hidden
+         self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
+         self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
+
+         # GLM side stays 256 -> hidden
+         self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
+
+         self.use_local_cnn = use_local_cnn_on_glm
+         self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
+
+         self.layers = nn.ModuleList(
+             [CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)]
+         )
+
+         self.ln_out = nn.LayerNorm(hidden_dim)
+         # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())  # OLD: returned probabilities
+         self.head = nn.Linear(hidden_dim, 1)  # NEW: return logits (safe for AMP)
+
+     def forward(self, binder_emb, glm_emb):
+         """
+         binder_emb: (B, Lb, binder_input_dim)
+         glm_emb: (B, Lg, glm_input_dim)
+         Returns per-nucleotide logits for the GLM sequence: (B, Lg)
+         """
+         # Binder: learnable compression → 256 → hidden
+         b = self.binder_compress(binder_emb)  # (B, Lb, 256)
+         b = self.proj_binder(b)  # (B, Lb, hidden_dim)
+
+         # GLM: project → hidden, add local CNN context
+         g = self.proj_glm(glm_emb)  # (B, Lg, hidden_dim)
+         if self.use_local_cnn:
+             g = self.local_cnn(g)
+
+         # Cross-modal blocks: update the GLM states using the binder
+         for layer in self.layers:
+             g = layer(b, g)  # (B, Lg, hidden_dim)
+
+         # Predict per-nucleotide logits on the GLM tokens:
+         # return self.head(g).squeeze(-1)  # OLD: probabilities (with Sigmoid in head)
+         return self.head(g).squeeze(
+             -1
+         )  # NEW: logits (apply sigmoid only in loss/metrics)
+
+     # ----- Lightning hooks -----
+     def training_step(self, batch, batch_idx):
+         """
+         Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
+         Collator returns a dictionary with:
+             "binder_emb"   # [B, Lb_max, Db]
+             "binder_mask"  # [B, Lb_max]
+             "glm_emb"      # [B, Lg_max, Dg]
+             "glm_mask"     # [B, Lg_max]
+             "labels"       # [B, Lg_max]
+             "ID"
+             "tr_sequence"
+             "dna_sequence"
+         """
+         logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+         loss = calculate_loss(
+             logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+         )
+         self.log(
+             "train/loss",
+             loss,
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             batch_size=logits.size(0),
+         )
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+         loss = calculate_loss(
+             logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+         )
+         self.log(
+             "val/loss",
+             loss,
+             on_step=False,
+             on_epoch=True,
+             prog_bar=True,
+             batch_size=logits.size(0),
+         )
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+         loss = calculate_loss(
+             logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+         )
+         self.log(
+             "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
+         )
+         return loss
+
+     def on_train_epoch_end(self):
+         if False:
+             if self.train_auc.compute() is not None:
+                 self.log("train/auroc", self.train_auc.compute(), prog_bar=True)
+             self.train_auc.reset()
+
+     def on_validation_epoch_end(self):
+         if False:
+             if self.val_auc.compute() is not None:
+                 self.log("val/auroc", self.val_auc.compute(), prog_bar=True)
+             self.val_auc.reset()
+
+     def on_test_epoch_end(self):
+         if False:
+             if self.test_auc.compute() is not None:
+                 self.log("test/auroc", self.test_auc.compute(), prog_bar=True)
+             self.test_auc.reset()
+
+     def configure_optimizers(self):
+         # AdamW + cosine as a sensible default
+         opt = torch.optim.AdamW(
+             self.parameters(),
+             lr=self.hparams.lr,
+             weight_decay=self.hparams.weight_decay,
+         )
+         # Scheduler optional—comment out if you prefer fixed LR
+         sch = torch.optim.lr_scheduler.CosineAnnealingLR(
+             opt, T_max=max(self.trainer.max_epochs, 1)
+         )
+         return {
+             "optimizer": opt,
+             "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
+         }
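Note: a shape-level smoke test for the BindPredictor forward pass, with dimensions matching the defaults above (illustrative only, not part of this commit):

import torch
from dpacman.classifier.model import BindPredictor

model = BindPredictor(binder_input_dim=1280, glm_input_dim=256,
                      compressed_dim=256, hidden_dim=256, heads=8, num_layers=2)
binder_emb = torch.randn(2, 300, 1280)   # (B, Lb, binder_input_dim), e.g. ESM protein tokens
glm_emb = torch.randn(2, 1000, 256)      # (B, Lg, glm_input_dim), per-nucleotide DNA embeddings
with torch.no_grad():
    logits = model(binder_emb, glm_emb)
assert logits.shape == (2, 1000)         # one logit per nucleotide of the DNA peak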
dpacman/classifier/model/clustering_data.py DELETED
@@ -1,383 +0,0 @@
1
- #!/usr/bin/env python3
2
- import argparse
3
- import numpy as np
4
- import pandas as pd
5
- from pathlib import Path
6
- import random
7
- import sys
8
- import subprocess
9
- from collections import defaultdict
10
-
11
- # ─────────────────────────────────────────────────────────────────────────
12
- # Original helpers (kept; some lightly edited/commented where needed)
13
- # ─────────────────────────────────────────────────────────────────────────
14
-
15
- def read_ids_file(p):
16
- p = Path(p)
17
- if not p.exists():
18
- raise FileNotFoundError(f"IDs file not found: {p}")
19
- return [line.strip() for line in p.open() if line.strip()]
20
-
21
- def split_embeddings(emb_path, ids_path, out_dir, prefix):
22
- out_dir = Path(out_dir)
23
- out_dir.mkdir(parents=True, exist_ok=True)
24
-
25
- if not Path(emb_path).exists():
26
- raise FileNotFoundError(f"Embedding file not found: {emb_path}")
27
- if not Path(ids_path).exists():
28
- raise FileNotFoundError(f"IDs file not found: {ids_path}")
29
-
30
- if emb_path.endswith(".npz"):
31
- data = np.load(emb_path, allow_pickle=True)
32
- if "embeddings" in data:
33
- emb = data["embeddings"]
34
- else:
35
- raise ValueError(f"{emb_path} missing 'embeddings' key")
36
- else:
37
- emb = np.load(emb_path)
38
-
39
- ids = read_ids_file(ids_path)
40
- if len(ids) != emb.shape[0]:
41
- print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)
42
-
43
- mapping = {}
44
- for i, ident in enumerate(ids):
45
- if i >= emb.shape[0]:
46
- print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
47
- continue
48
- arr = emb[i]
49
- out_file = out_dir / f"{prefix}_{ident}.npy"
50
- np.save(out_file, arr)
51
- mapping[ident] = str(out_file)
52
- return mapping
53
-
54
- def extract_symbol_from_tf_id(full_id: str) -> str:
55
- """
56
- Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
57
- return the gene symbol uppercase (e.g., 'ZBTB5').
58
- """
59
- if "|" in full_id:
60
- try:
61
- # format sp|Accession|SYMBOL_HUMAN
62
- genepart = full_id.split("|")[2]
63
- except IndexError:
64
- genepart = full_id
65
- else:
66
- genepart = full_id
67
- symbol = genepart.split("_")[0]
68
- return symbol.upper()
69
-
70
- def build_tf_symbol_map(tf_map):
71
- """
72
- Build mapping gene_symbol -> list of embedding paths.
73
- """
74
- symbol_map = {}
75
- for full_id, path in tf_map.items():
76
- symbol = extract_symbol_from_tf_id(full_id)
77
- symbol_map.setdefault(symbol, []).append(path)
78
- return symbol_map
79
-
80
- def tf_key_from_path(path: str) -> str:
81
- """
82
- Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
83
- """
84
- stem = Path(path).stem # e.g., tf_sp|O15062|ZBTB5_HUMAN
85
- # remove leading prefix if present (tf_)
86
- if "_" in stem:
87
- _, rest = stem.split("_", 1)
88
- else:
89
- rest = stem
90
- return extract_symbol_from_tf_id(rest)
91
-
92
- def dna_key_from_path(path: str) -> str:
93
- """
94
- Given .../dna_peak42.npy -> 'peak42'
95
- """
96
- stem = Path(path).stem
97
- if "_" in stem:
98
- _, rest = stem.split("_", 1)
99
- else:
100
- rest = stem
101
- return rest
102
-
103
- # ─────────────────────────────────────────────────────────────────────────
104
- # New helpers for MMseqs clustering & cluster-level splitting
105
- # ─────────────────────────────────────────────────────────────────────────
106
-
107
- def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
108
- """
109
- Write unique DNA sequences to FASTA using dna_id as header.
110
- Requires df with columns: dna_id, dna_sequence
111
- """
112
- uniq = df[["dna_id", "dna_sequence"]].drop_duplicates()
113
- with open(out_fasta, "w") as f:
114
- for _, row in uniq.iterrows():
115
- did = row["dna_id"]
116
- seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
117
- f.write(f">{did}\n{seq}\n")
118
-
119
- def run_mmseqs_easy_cluster(
120
- mmseqs_bin: str,
121
- fasta: Path,
122
- out_prefix: Path,
123
- tmp_dir: Path,
124
- min_seq_id: float,
125
- coverage: float,
126
- cov_mode: int,
127
- ) -> Path:
128
- """
129
- Runs mmseqs easy-cluster on nucleotide sequences.
130
- Returns the path to a clusters TSV file (creating it if the default one isn't present).
131
- """
132
- tmp_dir.mkdir(parents=True, exist_ok=True)
133
- out_prefix.parent.mkdir(parents=True, exist_ok=True)
134
-
135
- cmd = [
136
- mmseqs_bin, "easy-cluster",
137
- str(fasta), str(out_prefix), str(tmp_dir),
138
- "--min-seq-id", str(min_seq_id),
139
- "-c", str(coverage),
140
- "--cov-mode", str(cov_mode),
141
- # You can add performance flags here if needed, e.g.:
142
- # "--threads", "8"
143
- ]
144
- print("[i] Running:", " ".join(cmd), flush=True)
145
- subprocess.run(cmd, check=True)
146
-
147
- # MMseqs easy-cluster typically writes <out_prefix>_cluster.tsv
148
- default_tsv = Path(str(out_prefix) + "_cluster.tsv")
149
- if default_tsv.exists():
150
- print(f"[i] Found cluster TSV: {default_tsv}")
151
- return default_tsv
152
-
153
- # Fallback: try createtsv if default is missing
154
- # This requires the internal DBs. easy-cluster creates DBs alongside out_prefix.
155
- # We'll try to locate them and emit a TSV.
156
- in_db = Path(str(out_prefix) + "_query")
157
- cl_db = Path(str(out_prefix) + "_cluster")
158
- out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
159
- if in_db.exists() and cl_db.exists():
160
- cmd2 = [mmseqs_bin, "createtsv", str(in_db), str(in_db), str(cl_db), str(out_tsv)]
161
- print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
162
- subprocess.run(cmd2, check=True)
163
- if out_tsv.exists():
164
- return out_tsv
165
-
166
- raise FileNotFoundError("Could not locate clusters TSV from mmseqs. "
167
- "Expected {default_tsv} or createtsv fallback.")
168
-
169
- def parse_mmseqs_clusters(tsv_path: Path) -> dict:
170
- """
171
- Parse MMseqs cluster TSV (rep \t member). Returns dna_id -> cluster_rep_id
172
- """
173
- mapping = {}
174
- with open(tsv_path) as f:
175
- for line in f:
176
- parts = line.rstrip("\n").split("\t")
177
- if len(parts) < 2:
178
- continue
179
- rep, member = parts[0], parts[1]
180
- mapping[member] = rep
181
- # Some TSVs include rep->rep; if not, ensure rep is mapped to itself:
182
- if rep not in mapping:
183
- mapping[rep] = rep
184
- return mapping
185
-
186
- def assign_clusters_to_splits(cluster_rep_to_members: dict,
187
- val_frac: float,
188
- test_frac: float,
189
- seed: int = 42):
190
- """
191
- cluster_rep_to_members: dict[rep] = [members...]
192
- Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
193
- Ensures all members of a cluster go to the same split.
194
- """
195
- rng = random.Random(seed)
196
- reps = list(cluster_rep_to_members.keys())
197
- rng.shuffle(reps)
198
-
199
- # Greedy-ish fill by total member counts to match desired fractions.
200
- total = sum(len(cluster_rep_to_members[r]) for r in reps)
201
- target_val = int(round(total * val_frac))
202
- target_test = int(round(total * test_frac))
203
- cur_val = cur_test = 0
204
-
205
- val_ids, test_ids, train_ids = set(), set(), set()
206
- for rep in reps:
207
- members = cluster_rep_to_members[rep]
208
- c = len(members)
209
- # Fill val first, then test, then train
210
- if cur_val + c <= target_val:
211
- val_ids.update(members); cur_val += c
212
- elif cur_test + c <= target_test:
213
- test_ids.update(members); cur_test += c
214
- else:
215
- train_ids.update(members)
216
-
217
- return {"train": train_ids, "val": val_ids, "test": test_ids}
218
-
219
- # ─────────────────────────────────────────────────────────────────────────
220
- # Main
221
- # ─────────────────────────────────────────────────────────────────────────
222
-
223
- def main():
224
- parser = argparse.ArgumentParser(
225
- description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
226
- )
227
- parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
228
- parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
229
- parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)")
230
- parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
231
- parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)")
232
- parser.add_argument("--out_dir", required=True, help="Output directory")
233
- parser.add_argument("--seed", type=int, default=42)
234
-
235
- # NEW: MMseqs options & split fractions
236
- parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
237
- parser.add_argument("--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id")
238
- parser.add_argument("--cov", type=float, default=0.8, help="MMseqs -c coverage fraction")
239
- parser.add_argument("--cov_mode", type=int, default=1, help="MMseqs --cov-mode (1 = coverage of target)")
240
- parser.add_argument("--val_frac", type=float, default=0.10)
241
- parser.add_argument("--test_frac", type=float, default=0.10)
242
- parser.add_argument("--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)")
243
- args = parser.parse_args()
244
-
245
- random.seed(args.seed)
246
- out_dir = Path(args.out_dir)
247
- out_dir.mkdir(parents=True, exist_ok=True)
248
-
249
- # Load final.csv
250
- df = pd.read_csv(args.final_csv, dtype=str)
251
- if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
252
- raise RuntimeError("final.csv must have columns TF_id and dna_sequence")
253
-
254
- # Assign dna_id (unique per dna_sequence)
255
- unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
256
- seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
257
- df["dna_id"] = df["dna_sequence"].map(seq_to_id)
258
- enriched_csv = out_dir / "final_with_dna_id.csv"
259
- df.to_csv(enriched_csv, index=False)
260
- print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
261
-
262
- # Split embeddings into per-item files (unchanged)
263
- print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
264
- dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
265
- print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
266
- print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
267
- tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
268
- print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")
269
-
270
- # Build gene-symbol normalized map
271
- tf_symbol_map = build_tf_symbol_map(tf_map)
272
- print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")
273
-
274
- # Diagnostic overlaps
275
- norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
276
- available_tf_symbols = set(tf_symbol_map.keys())
277
- intersect_tf = norm_tf_in_final & available_tf_symbols
278
- print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
279
- print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
280
- print(f"[i] Intersection count: {len(intersect_tf)}")
281
- if len(intersect_tf) == 0:
282
- print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
283
- print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
284
- print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
285
- sys.exit(1)
286
-
287
- dna_ids_final = set(df["dna_id"].unique())
288
- available_dna_ids = set(dna_map.keys())
289
- intersect_dna = dna_ids_final & available_dna_ids
290
- print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
291
- if len(intersect_dna) == 0:
292
- print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
293
- sys.exit(1)
294
-
295
- # ── NEW: MMseqs clustering on DNA sequences ───────────────────────────
296
- fasta_path = out_dir / "dna_unique.fasta"
297
- write_dna_fasta(df, fasta_path)
298
- print(f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}")
299
-
300
- tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
301
- cluster_prefix = out_dir / "mmseqs_dna_clusters"
302
- clusters_tsv = run_mmseqs_easy_cluster(
303
- mmseqs_bin=args.mmseqs_bin,
304
- fasta=fasta_path,
305
- out_prefix=cluster_prefix,
306
- tmp_dir=tmp_dir,
307
- min_seq_id=args.min_seq_id,
308
- coverage=args.cov,
309
- cov_mode=args.cov_mode,
310
- )
311
-
312
- # Parse clusters
313
- member_to_rep = parse_mmseqs_clusters(clusters_tsv) # dna_id -> rep_id
314
- # Build rep -> members list
315
- rep_to_members = defaultdict(list)
316
- for member, rep in member_to_rep.items():
317
- rep_to_members[rep].append(member)
318
-
319
- print(f"[i] Parsed {len(rep_to_members)} clusters from {clusters_tsv}")
320
- clusters_table = []
321
- for rep, members in rep_to_members.items():
322
- for m in members:
323
- clusters_table.append((m, rep))
324
- clusters_df = pd.DataFrame(clusters_table, columns=["dna_id", "cluster_id"])
325
- clusters_df.to_csv(out_dir / "clusters.tsv", sep="\t", index=False)
326
- print(f"[i] Wrote clusters mapping → {out_dir / 'clusters.tsv'}")
327
-
328
- # Attach cluster_id back to final df
329
- df = df.merge(clusters_df, on="dna_id", how="left")
330
- df.to_csv(out_dir / "final_with_dna_id_and_cluster.csv", index=False)
331
- print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")
332
-
333
- # Assign entire clusters to splits
334
- splits = assign_clusters_to_splits(rep_to_members,
335
- val_frac=args.val_frac,
336
- test_frac=args.test_frac,
337
- seed=args.seed)
338
- for k in ["train", "val", "test"]:
339
- print(f"[i] {k}: {len(splits[k])} dna_ids")
340
-
341
- # ── Build positive pairs only, per split (NO negatives) ───────────────
342
- positives_by_split = {"train": [], "val": [], "test": []}
343
- # Build a quick dna_id -> embedding path map
344
- dnaid_to_path = {did: path for did, path in dna_map.items()}
345
-
346
- pos_count = 0
347
- for _, row in df.iterrows():
348
- tf_raw = row["TF_id"]
349
- tf_symbol = tf_raw.split("_seq")[0].upper()
350
- dnaid = row["dna_id"]
351
- if (tf_symbol not in tf_symbol_map) or (dnaid not in dnaid_to_path):
352
- continue
353
- tf_embedding_path = tf_symbol_map[tf_symbol][0] # first embedding per symbol
354
-
355
- # decide split by dna_id cluster assignment
356
- if dnaid in splits["train"]:
357
- positives_by_split["train"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
358
- elif dnaid in splits["val"]:
359
- positives_by_split["val"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
360
- elif dnaid in splits["test"]:
361
- positives_by_split["test"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
362
- pos_count += 1
363
-
364
- print(f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})")
365
- for k in ["train", "val", "test"]:
366
- print(f"[i] positives[{k}] = {len(positives_by_split[k])}")
367
-
368
- # # OLD: negatives (kept commented)
369
- # negatives = []
370
- # print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive not used)")
371
-
372
- # Emit split-specific pair lists
373
- for split in ["train", "val", "test"]:
374
- out_tsv = out_dir / f"pair_list_{split}.tsv"
375
- with open(out_tsv, "w") as f:
376
- for binder_path, glm_path, label in positives_by_split[split]: # + negatives if you add later
377
- f.write(f"{binder_path}\t{glm_path}\t{label}\n")
378
- print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")
379
-
380
- print("✅ Done. Cluster-aware splits ready.")
381
-
382
- if __name__ == "__main__":
383
- main()
dpacman/classifier/model/compress_embeddings.py DELETED
@@ -1,54 +0,0 @@
- # compress_embeddings.py
- # USAGE: python compress_embeddings.py --input_glob "/path/to/esm_embeddings/*.npy" --output_dir "/path/to/compressed_embeddings" --esm_dim 1280 --out_dim 256
- # --------------
- import os
- import glob
- import numpy as np
- import torch
- from torch import nn
-
- class EmbeddingCompressor(nn.Module):
-     def __init__(self, input_dim: int = 1280, output_dim: int = 256):
-         super().__init__()
-         self.fc = nn.Linear(input_dim, output_dim)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         x: (batch, L, input_dim) or (L, input_dim)
-         returns: (batch, output_dim) or (output_dim,)
-         """
-         if x.dim() == 2:
-             # single example: mean over tokens
-             x = x.mean(dim=0, keepdim=True)  # → (1, input_dim)
-         else:
-             # batch: mean over tokens
-             x = x.mean(dim=1)  # → (batch, input_dim)
-         return self.fc(x)  # → (batch, output_dim)
-
- def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
-     arr = np.load(in_path)  # shape (L, D) or (batch, L, D)
-     tensor = torch.from_numpy(arr).float()
-     with torch.no_grad():
-         compressed = model(tensor)  # → (batch, 256)
-     out = compressed.cpu().numpy()
-     np.save(out_path, out)
-     print(f"Saved {out_path}")
-
- if __name__ == "__main__":
-     import argparse
-     parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256-d")
-     parser.add_argument("--input_glob", type=str, required=True,
-                         help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)")
-     parser.add_argument("--output_dir", type=str, required=True)
-     parser.add_argument("--esm_dim", type=int, default=1280)
-     parser.add_argument("--out_dim", type=int, default=256)
-     args = parser.parse_args()
-
-     os.makedirs(args.output_dir, exist_ok=True)
-     compressor = EmbeddingCompressor(args.esm_dim, args.out_dim)
-     compressor.eval()
-
-     for fn in glob.glob(args.input_glob):
-         base = os.path.basename(fn).replace(".npy", "_256.npy")
-         out_path = os.path.join(args.output_dir, base)
-         compress_file(fn, out_path, compressor)
dpacman/classifier/model/compute_embeddings.py DELETED
@@ -1,560 +0,0 @@
1
- """
2
- Plug-and-play embedding extraction for:
3
- • Chromosome sequences (from raw UCSC JSON)
4
- • TF sequences (transcription_factors.fasta)
5
-
6
- Usage example (DNA + protein in one go):
7
- module load miniconda/24.7.1
8
- conda activate dpacman
9
- python dpacman/data/compute_embeddings.py \
10
- --genome-json-dir ../data_files/raw/genomes/hg38 \
11
- --tf-fasta ../data_files/processed/tfclust/hg38_tf/transcription_factors.fasta \
12
- --chrom-model caduceus \
13
- --tf-model esm-dbp \
14
- --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
15
- --device cuda
16
- """
17
- import os
18
- import re
19
- import argparse
20
- import json
21
- import numpy as np
22
- from pathlib import Path
23
- import torch
24
- from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline
25
- import esm
26
- from Bio import SeqIO
27
- import time
28
- import pandas as pd
29
- from tqdm.auto import tqdm
30
- import logging, math
31
-
32
- # ---- model wrappers ----
33
-
34
- class CaduceusEmbedder:
35
- def __init__(self, device, chunk_size=131_072, overlap=0):
36
- """
37
- device: 'cpu' or 'cuda'
38
- chunk_size: max bases (and thus tokens) to send in one forward pass
39
- overlap: how many bases each window overlaps the previous; 0 = no overlap
40
- """
41
- model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
42
- self.tokenizer = AutoTokenizer.from_pretrained(
43
- model_name, trust_remote_code=True
44
- )
45
- self.model = AutoModel.from_pretrained(
46
- model_name, trust_remote_code=True
47
- ).to(device).eval()
48
- self.device = device
49
- self.chunk_size = chunk_size
50
- self.step = chunk_size - overlap
51
-
52
- def embed(self, seqs):
53
- """
54
- seqs: List[str] of DNA sequences (each <= chunk_size for this test)
55
- returns: np.ndarray of shape (N, L, D), raw per‐token embeddings
56
- """
57
- # outputs = []
58
- # for seq in seqs:
59
- # # --- new: raw per‐token embeddings in one shot ---
60
- # toks = self.tokenizer(
61
- # seq,
62
- # return_tensors="pt",
63
- # padding=False,
64
- # truncation=True,
65
- # max_length=self.chunk_size
66
- # ).to(self.device)
67
- # with torch.no_grad():
68
- # out = self.model(**toks).last_hidden_state # (1, L, D)
69
- # outputs.append(out.cpu().numpy()[0]) # (L, D)
70
-
71
- # return np.stack(outputs, axis=0) # (N, L, D)
72
- outputs = []
73
- for seq in tqdm(seqs, total=len(seqs), desc="DNA: Caduceus", dynamic_ncols=True):
74
- toks = self.tokenizer(
75
- seq,
76
- return_tensors="pt",
77
- padding=False,
78
- truncation=True,
79
- max_length=self.chunk_size
80
- ).to(self.device)
81
- with torch.no_grad():
82
- out = self.model(**toks).last_hidden_state # (1, L, D)
83
- outputs.append(out.cpu().numpy()[0]) # (L, D)
84
- return outputs # list of variable-length (L_i, D) arrays
85
-
86
-
87
- def benchmark(self, lengths=None):
88
- """
89
- Time embedding on single-sequence of various lengths.
90
- By default tests [5K,10K,50K,100K,chunk_size].
91
- """
92
- tests = lengths or [5_000, 10_000, 50_000, 100_000, self.chunk_size]
93
- print(f"→ Benchmarking Caduceus on device={self.device}")
94
- for sz in tests:
95
- seq = "A" * sz
96
- # Warm-up
97
- _ = self.embed([seq])
98
- if self.device != "cpu":
99
- torch.cuda.synchronize()
100
- t0 = time.perf_counter()
101
- _ = self.embed([seq])
102
- if self.device != "cpu":
103
- torch.cuda.synchronize()
104
- t1 = time.perf_counter()
105
- print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms")
106
-
107
- class SegmentNTEmbedder:
108
- def __init__(self, device):
109
- self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
110
- self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval()
111
- self.device = device
112
-
113
- def _adjust_length(self, input_ids):
114
- bs, L = input_ids.shape
115
- excl = L - 1
116
- remainder = (excl) % 4
117
- if remainder != 0:
118
- pad_needed = 4 - remainder
119
- pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
120
- input_ids = torch.cat([input_ids, pad_tensor], dim=1)
121
- return input_ids
122
-
123
- def embed(self, seqs, batch_size=16):
124
- """
125
- seqs: List[str]
126
- Returns: np.ndarray of shape (N, D)
127
- """
128
- all_embeddings = []
129
- for i in range(0, len(seqs), batch_size):
130
- batch_seqs = seqs[i : i + batch_size]
131
- encoded = self.tokenizer.batch_encode_plus(
132
- batch_seqs,
133
- return_tensors="pt",
134
- padding=True,
135
- truncation=True,
136
- )
137
- input_ids = encoded["input_ids"].to(self.device) # (B, L)
138
- attention_mask = input_ids != self.tokenizer.pad_token_id
139
-
140
- input_ids = self._adjust_length(input_ids)
141
- attention_mask = (input_ids != self.tokenizer.pad_token_id)
142
-
143
- with torch.no_grad():
144
- outs = self.model(
145
- input_ids,
146
- attention_mask=attention_mask,
147
- output_hidden_states=True,
148
- return_dict=True,
149
- )
150
- if hasattr(outs, "hidden_states") and outs.hidden_states is not None:
151
- last_hidden = outs.hidden_states[-1] # (B, L, D)
152
- else:
153
- last_hidden = outs.last_hidden_state # fallback
154
-
155
- # Exclude CLS token if present (assume first token) and pool
156
- pooled = last_hidden[:, 1:, :].mean(dim=1) # (B, D)
157
- all_embeddings.append(pooled.cpu().numpy())
158
-
159
- # release fragmentation
160
- torch.cuda.empty_cache()
161
-
162
- return np.vstack(all_embeddings) # (N, D)
163
-
164
-
165
- class DNABertEmbedder:
166
- def __init__(self, device):
167
- self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
168
- self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device)
169
- self.device = device
170
-
171
- def embed(self, seqs):
172
- embs = []
173
- for s in seqs:
174
- tokens = self.tokenizer(s, return_tensors="pt", padding=True)["input_ids"].to(self.device)
175
- with torch.no_grad():
176
- out = self.model(tokens).last_hidden_state.mean(1)
177
- embs.append(out.cpu().numpy())
178
- return np.vstack(embs)
179
-
180
- class NucleotideTransformerEmbedder:
181
- def __init__(self, device):
182
- # HF “feature-extraction” returns a list of (L, D) arrays for each input
183
- # device: “cpu” or “cuda”
184
- self.pipe = pipeline(
185
- "feature-extraction",
186
- model="InstaDeepAI/nucleotide-transformer-500m-1000g",
187
- device= -1 if device=="cpu" else 0 # HF uses -1 for CPU, 0 for GPU #:contentReference[oaicite:0]{index=0}
188
- )
189
-
190
- def embed(self, seqs):
191
- """
192
- seqs: List[str] of raw DNA sequences
193
- returns: (N, D) array, one D-dim vector per sequence
194
- """
195
- all_embeddings = self.pipe(seqs, truncation=True, padding=True)
196
- # all_embeddings is a List of shape (L, D) arrays
197
- pooled = [ np.mean(x, axis=0) for x in all_embeddings ]
198
- return np.vstack(pooled)
199
-
200
- # class ESMEmbedder:
201
- # def __init__(self, device):
202
- # self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
203
- # self.batch_converter = self.alphabet.get_batch_converter()
204
- # self.model.to(device).eval()
205
- # self.device = device
206
-
207
- # def embed(self, seqs):
208
- # batch = [(str(i), seq) for i, seq in enumerate(seqs)]
209
- # _, _, toks = self.batch_converter(batch)
210
- # toks = toks.to(self.device)
211
- # with torch.no_grad():
212
- # results = self.model(toks, repr_layers=[33], return_contacts=False)
213
- # reps = results["representations"][33]
214
- # return reps[:, 1:-1].mean(1).cpu().numpy()
215
-
216
-
217
- class ESMEmbedder:
218
- def __init__(self, device, model_name="esm2_t33_650M_UR50D"):
219
- # Try to load the specified ESM-2 model; fallback to esm1b if missing
220
- self.device = device
221
- try:
222
- self.model, self.alphabet = getattr(esm.pretrained, model_name)()
223
- self.is_esm2 = model_name.lower().startswith("esm2")
224
- except AttributeError:
225
- # fallback to ESM-1b
226
- self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
227
- self.is_esm2 = False
228
- self.batch_converter = self.alphabet.get_batch_converter()
229
- self.model.to(device).eval()
230
- # determine max length: esm2 models vary; use default 1024 for esm1b
231
- self.max_len = 4096 if self.is_esm2 else 1024 # adjust if your esm2 variant has explicit limit
232
- # for chunking: reserve 2 tokens if model uses BOS/EOS
233
- self.chunk_size = self.max_len - 2
234
- self.overlap = self.chunk_size // 4 # 25% overlap to smooth boundaries
235
-
236
- def _chunk_sequence(self, seq):
237
- """
238
- Return list of possibly overlapping chunks of seq, each <= chunk_size.
239
- """
240
- if len(seq) <= self.chunk_size:
241
- return [seq]
242
- step = self.chunk_size - self.overlap
243
- chunks = []
244
- for i in range(0, len(seq), step):
245
- chunk = seq[i : i + self.chunk_size]
246
- if not chunk:
247
- break
248
- chunks.append(chunk)
249
- return chunks
250
-
251
- def embed(self, seqs):
252
- """
253
- seqs: List[str] of protein sequences.
254
- Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings.
255
- """
256
- all_embeddings = []
257
- for i, seq in enumerate(seqs):
258
- chunks = self._chunk_sequence(seq)
259
- chunk_vecs = []
260
- # process chunks in batch if small number, else sequentially
261
- for chunk in chunks:
262
- batch = [(str(i), chunk)]
263
- _, _, toks = self.batch_converter(batch)
264
- toks = toks.to(self.device)
265
- with torch.no_grad():
266
- results = self.model(toks, repr_layers=[33], return_contacts=False)
267
- reps = results["representations"][33] # (1, L, D)
268
- # remove BOS/EOS if present: take 1:-1 if length permits
269
- if reps.size(1) > 2:
270
- rep = reps[:, 1:-1].mean(1) # (1, D)
271
- else:
272
- rep = reps.mean(1) # fallback
273
- chunk_vecs.append(rep.squeeze(0)) # (D,)
274
- if len(chunk_vecs) == 1:
275
- seq_vec = chunk_vecs[0]
276
- else:
277
- # average chunk vectors
278
- stacked = torch.stack(chunk_vecs, dim=0) # (num_chunks, D)
279
- seq_vec = stacked.mean(0)
280
- all_embeddings.append(seq_vec.cpu().numpy())
281
- return np.vstack(all_embeddings) # (N, D)
282
-
283
-
284
- # class ESMDBPEmbedder:
285
- # def __init__(self, device):
286
- # base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
287
- # model_path = (
288
- # Path(__file__).resolve().parent.parent
289
- # / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
290
- # )
291
- # checkpoint = torch.load(model_path, map_location="cpu")
292
- # clean_sd = {}
293
- # for k, v in checkpoint.items():
294
- # clean_sd[k.replace("module.", "")] = v
295
- # result = base_model.load_state_dict(clean_sd, strict=False)
296
- # if result.missing_keys:
297
- # print(f"[ESMDBP] missing keys: {result.missing_keys}")
298
- # if result.unexpected_keys:
299
- # print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
300
-
301
- # self.model = base_model.to(device).eval()
302
- # self.alphabet = alphabet
303
- # self.batch_converter = alphabet.get_batch_converter()
304
- # self.device = device
305
-
306
- # def embed(self, seqs):
307
- # batch = [(str(i), seq) for i, seq in enumerate(seqs)]
308
- # _, _, toks = self.batch_converter(batch)
309
- # toks = toks.to(self.device)
310
- # with torch.no_grad():
311
- # out = self.model(toks, repr_layers=[33], return_contacts=False)
312
- # reps = out["representations"][33]
313
- # # skip start/end tokens
314
- # return reps[:, 1:-1].mean(1).cpu().numpy()
315
-
316
- class ESMDBPEmbedder:
317
- def __init__(self, device):
318
- base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
319
- model_path = (
320
- Path(__file__).resolve().parent.parent
321
- / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
322
- )
323
- checkpoint = torch.load(model_path, map_location="cpu")
324
- clean_sd = {}
325
- for k, v in checkpoint.items():
326
- clean_sd[k.replace("module.", "")] = v
327
- result = base_model.load_state_dict(clean_sd, strict=False)
328
- if result.missing_keys:
329
- print(f"[ESMDBP] missing keys: {result.missing_keys}")
330
- if result.unexpected_keys:
331
- print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
332
-
333
- self.model = base_model.to(device).eval()
334
- self.alphabet = alphabet
335
- self.batch_converter = alphabet.get_batch_converter()
336
- self.device = device
337
- self.max_len = 1024 # same limit as esm1b
338
- self.chunk_size = self.max_len - 2
339
- self.overlap = self.chunk_size // 4
340
-
341
- def _chunk_sequence(self, seq):
342
- if len(seq) <= self.chunk_size:
343
- return [seq]
344
- step = self.chunk_size - self.overlap
345
- chunks = []
346
- for i in range(0, len(seq), step):
347
- chunk = seq[i : i + self.chunk_size]
348
- if not chunk:
349
- break
350
- chunks.append(chunk)
351
- return chunks
352
-
353
- def embed(self, seqs):
354
- all_embeddings = []
355
- for i, seq in enumerate(seqs):
356
- chunks = self._chunk_sequence(seq)
357
- chunk_vecs = []
358
- for chunk in chunks:
359
- batch = [(str(i), chunk)]
360
- _, _, toks = self.batch_converter(batch)
361
- toks = toks.to(self.device)
362
- with torch.no_grad():
363
- out = self.model(toks, repr_layers=[33], return_contacts=False)
364
- reps = out["representations"][33]
365
- if reps.size(1) > 2:
366
- rep = reps[:, 1:-1].mean(1)
367
- else:
368
- rep = reps.mean(1)
369
- chunk_vecs.append(rep.squeeze(0))
370
- if len(chunk_vecs) == 1:
371
- seq_vec = chunk_vecs[0]
372
- else:
373
- stacked = torch.stack(chunk_vecs, dim=0)
374
- seq_vec = stacked.mean(0)
375
- all_embeddings.append(seq_vec.cpu().numpy())
376
- return np.vstack(all_embeddings)
377
-
378
- class GPNEmbedder:
379
- def __init__(self, device):
380
- model_name = "songlab/gpn-msa-sapiens"
381
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
382
- self.model = AutoModelForMaskedLM.from_pretrained(model_name)
383
- self.model.to(device)
384
- self.model.eval()
385
- self.device = device
386
-
387
- def embed(self, seqs):
388
- inputs = self.tokenizer(
389
- seqs,
390
- return_tensors="pt",
391
- padding=True,
392
- truncation=True
393
- ).to(self.device)
394
-
395
- with torch.no_grad():
396
- last_hidden = self.model(**inputs).last_hidden_state
397
- return last_hidden.mean(dim=1).cpu().numpy()
398
-
399
- class ProGenEmbedder:
400
- def __init__(self, device):
401
- model_name = "jinyuan22/ProGen2-base"
402
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
403
- self.model = AutoModel.from_pretrained(model_name).to(device).eval()
404
- self.device = device
405
-
406
- def embed(self, seqs):
407
- inputs = self.tokenizer(
408
- seqs,
409
- return_tensors="pt",
410
- padding=True,
411
- truncation=True
412
- ).to(self.device)
413
- with torch.no_grad():
414
- last_hidden = self.model(**inputs).last_hidden_state
415
- return last_hidden.mean(dim=1).cpu().numpy()
416
-
417
- # ---- main pipeline ----
418
-
419
- def get_embedder(name, device, for_dna=True):
420
- name = name.lower()
421
- if for_dna:
422
- if name=="caduceus": return CaduceusEmbedder(device)
423
- if name=="dnabert": return DNABertEmbedder(device)
424
- if name=="nucleotide": return NucleotideTransformerEmbedder(device)
425
- if name=="gpn": return GPNEmbedder(device)
426
- if name=="segmentnt": return SegmentNTEmbedder(device)
427
- else:
428
- if name in ("esm",): return ESMEmbedder(device)
429
- if name in ("esm-dbp","esm_dbp"): return ESMDBPEmbedder(device)
430
- if name=="progen": return ProGenEmbedder(device)
431
- raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
432
-
433
-
434
- def pad_token_embeddings(list_of_arrays, pad_value=0.0):
435
- """
436
- list_of_arrays: list of (L_i, D) numpy arrays
437
- Returns:
438
- padded: (N, L_max, D) array
439
- mask: (N, L_max) boolean array where True = real token, False = padding
440
- """
441
- N = len(list_of_arrays)
442
- D = list_of_arrays[0].shape[1]
443
- L_max = max(arr.shape[0] for arr in list_of_arrays)
444
- padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
445
- mask = np.zeros((N, L_max), dtype=bool)
446
- for i, arr in enumerate(list_of_arrays):
447
- L = arr.shape[0]
448
- padded[i, :L] = arr
449
- mask[i, :L] = True
450
- return padded, mask
451
-
452
- def embed_and_save(seqs, ids, embedder, out_path):
453
- embs = embedder.embed(seqs)
454
-
455
- # Decide whether we got variable-length per-token outputs (list of (L, D))
456
- is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and embs[0].ndim == 2
457
-
458
- if is_variable_token:
459
- # pad to (N, L_max, D) + mask
460
- padded, mask = pad_token_embeddings(embs)
461
- # Save both embeddings and mask together in an .npz for convenience
462
- np.savez_compressed(out_path.with_suffix(".caduceus.npz"),
463
- embeddings=padded,
464
- mask=mask,
465
- ids=np.array(ids, dtype=object))
466
- else:
467
- # fixed shape output, e.g., pooled (N, D)
468
- array = np.vstack(embs) if isinstance(embs, list) else embs
469
- np.save(out_path, array)
470
- with open(out_path.with_suffix(".ids"), "w") as f:
471
- f.write("\n".join(ids))
472
-
473
-
474
- if __name__=="__main__":
475
-
476
- p = argparse.ArgumentParser()
477
- #p.add_argument("--peak_fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
478
- p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
479
- p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step") #if glm embeddings successful but not plm embeddings
480
- p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
481
- p.add_argument("--chrom-model", default="caduceus")
482
- p.add_argument("--tf-model", default="esm-dbp")
483
- p.add_argument("--out-dir", default="dpacman/model/embeddings")
484
- p.add_argument("--device", default="cpu")
485
- args = p.parse_args()
486
-
487
- os.makedirs(args.out_dir, exist_ok=True)
488
- device = args.device
489
- print(device)
490
-
491
- if not args.skip_dna:
492
- if args.genome_json_dir == None:
493
- dna_df = pd.read_parquet('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.parquet', engine='pyarrow')
494
- #df.to_csv('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.csv', index=False)
495
- peak_seqs = dna_df["dna_sequence"]
496
- peak_ids = dna_df["ID"]
497
- print(f"Embedding {len(peak_seqs)} binding peak sequences from processed remap data", flush=True)
498
- dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
499
- out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
500
- embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
501
-
502
- # peak_fasta = Path(args.peak_fasta)
503
- # if peak_fasta.exists():
504
- # # Load peak sequences from FASTA
505
- # from Bio import SeqIO
506
-
507
- # peak_seqs = []
508
- # peak_ids = []
509
- # for rec in SeqIO.parse(peak_fasta, "fasta"):
510
- # peak_ids.append(rec.id)
511
- # peak_seqs.append(str(rec.seq))
512
- # print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
513
- # dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
514
- # out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
515
- # embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
516
- elif args.genome_json_dir:
517
- # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
518
- genome_dir = Path(args.genome_json_dir)
519
- chrom_seqs, chrom_ids = [], []
520
- primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$")
521
- for j in sorted(genome_dir.iterdir()):
522
- if not primary_pattern.match(j.name):
523
- continue
524
- data = json.loads(j.read_text())
525
- seq = data.get("dna") or data.get("sequence")
526
- chrom = data.get("chrom") or j.stem.split("_")[-1]
527
- chrom_seqs.append(seq)
528
- chrom_ids.append(chrom)
529
- cutoff = CaduceusEmbedder(device).chunk_size
530
- long_chroms = [
531
- (chrom, len(seq))
532
- for chrom, seq in zip(chrom_ids, chrom_seqs)
533
- if len(seq) > cutoff
534
- ]
535
- if long_chroms:
536
- print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff))
537
- for chrom, L in long_chroms:
538
- print(f" {chrom}: {L} bases")
539
- else:
540
- print("All chromosomes ≤ Caduceus limit ({}).".format(cutoff))
541
-
542
- chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
543
- out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
544
- embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
545
- else:
546
- raise ValueError("No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs.")
547
-
548
-
549
- #Load TF sequences
550
- tf_seqs, tf_ids = [], []
551
- for record in SeqIO.parse(args.tf_fasta, "fasta"):
552
- tf_ids.append(record.id)
553
- tf_seqs.append(str(record.seq))
554
-
555
- # embed and save
556
- tf_embedder = get_embedder(args.tf_model, device, for_dna=False)
557
- out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
558
- embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
559
-
560
- print("Done.")
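Note on the output convention of embed_and_save() above: token-level (variable-length) results are written to a compressed .npz holding "embeddings", "mask", and "ids", while pooled results go to a plain .npy with a sidecar .ids text file. A minimal loading sketch under that convention, using only NumPy; the path in the usage lines is a placeholder, not a file shipped with this commit:

# Minimal sketch of reading back what embed_and_save() writes (paths are placeholders).
import numpy as np
from pathlib import Path

def load_embeddings(out_path: Path):
    npz_path = out_path.with_suffix(".caduceus.npz")
    if npz_path.exists():
        # token-level case: padded (N, L_max, D) embeddings plus a boolean mask
        data = np.load(npz_path, allow_pickle=True)
        return data["embeddings"], data["mask"], list(data["ids"])
    # pooled case: (N, D) matrix with one id per line in the sidecar .ids file
    emb = np.load(out_path)
    ids = out_path.with_suffix(".ids").read_text().splitlines()
    return emb, None, ids

# Hypothetical usage (file names illustrative only):
# emb, mask, ids = load_embeddings(Path("embeddings/peaks_caduceus.npy"))
# print(emb.shape, len(ids))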
dpacman/classifier/model/extract_tf_symbols.py DELETED
@@ -1,27 +0,0 @@
1
- #!/usr/bin/env python3
2
- import pandas as pd
3
- from pathlib import Path
4
-
5
- FINAL_CSV = Path("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv")
6
- OUT_SYMBOLS = Path("tf_symbols.txt")
7
-
8
- def normalize_tf(tf_id: str) -> str:
9
- return tf_id.split("_seq")[0].upper()
10
-
11
- def main():
12
- df = pd.read_csv(FINAL_CSV, dtype=str)
13
- if "TF_id" not in df.columns:
14
- raise RuntimeError("final.csv missing TF_id column")
15
- tf_raw = df["TF_id"].dropna().unique().tolist()
16
- normalized = sorted({normalize_tf(t) for t in tf_raw})
17
- print(f"Unique raw TF_id count: {len(tf_raw)}")
18
- print(f"Unique normalized TF symbols: {len(normalized)}")
19
- with open(OUT_SYMBOLS, "w") as f:
20
- for s in normalized:
21
- f.write(s + "\n")
22
- print(f"Wrote normalized TF symbols to {OUT_SYMBOLS}")
23
- # Optional: show sample
24
- print("Sample symbols:", normalized[:50])
25
-
26
- if __name__ == "__main__":
27
- main()
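The normalization above just strips any "_seqN" suffix and upper-cases the symbol, so isoform-suffixed TF_id values collapse to one gene symbol. A quick sanity check with invented inputs:

# Toy check of the TF_id normalization used above (inputs are made up).
def normalize_tf(tf_id: str) -> str:
    return tf_id.split("_seq")[0].upper()

assert normalize_tf("FOXA1_seq2") == "FOXA1"
assert normalize_tf("ctcf") == "CTCF"
assert normalize_tf("ZBTB5_seq10") == "ZBTB5"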
dpacman/classifier/model/loss.py DELETED
@@ -1,34 +0,0 @@
1
- """
2
- Define loss functions needed for training the model
3
- """
4
- import torch
5
- from torch.nn import functional as F
6
-
7
- def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8):
8
- probs = torch.sigmoid(logits)
9
- labels = (targets >= peak_thresh).float()
10
- non_peak_mask = (labels == 0).float()
11
- peak_mask = (labels == 1).float()
12
-
13
- bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
14
- bce_non = (bce_all * non_peak_mask)
15
- bce_non = bce_non.sum() / (non_peak_mask.sum() + eps)
16
-
17
- mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction='sum') \
18
- / (peak_mask.sum() + eps)
19
-
20
- t_dist = (targets + eps)
21
- p_dist = (probs + eps)
22
- t_dist = t_dist / t_dist.sum(dim=1, keepdim=True)
23
- p_dist = p_dist / p_dist.sum(dim=1, keepdim=True)
24
- kl = (t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log())).sum(dim=1).mean()
25
-
26
- return bce_non, kl, mse_peaks, probs
27
-
28
- def accuracy_percentage(logits, targets, peak_thresh=0.5):
29
- probs = torch.sigmoid(logits)
30
- preds_bin = (probs >= 0.5).float()
31
- labels = (targets >= peak_thresh).float()
32
- correct = (preds_bin == labels).float().sum()
33
- total = torch.numel(labels)
34
- return (correct / max(1, total)).item() * 100.0
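combined_loss_components() returns its three terms separately (BCE over non-peak positions, a profile KL term, and an MSE restricted to peaks) and leaves the weighting to the caller. One plausible way to fold them into a single training loss, shown with placeholder weights and random tensors; this assumes the two functions above are in scope and is not the repo's actual training step:

# Illustrative only: w_bce / w_kl / w_mse are placeholders, not values from this repo.
import torch

logits = torch.randn(4, 100, requires_grad=True)   # (batch, positions)
targets = torch.rand(4, 100)                       # continuous targets in [0, 1]

bce_non, kl, mse_peaks, probs = combined_loss_components(logits, targets)
w_bce, w_kl, w_mse = 1.0, 1.0, 1.0                 # placeholder weights
loss = w_bce * bce_non + w_kl * kl + w_mse * mse_peaks
loss.backward()
print(float(loss), accuracy_percentage(logits, targets))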
dpacman/classifier/model/make_pair_list.py DELETED
@@ -1,220 +0,0 @@
1
- #!/usr/bin/env python3
2
- import argparse
3
- import numpy as np
4
- import pandas as pd
5
- from pathlib import Path
6
- import random
7
- import sys
8
-
9
- def read_ids_file(p):
10
- p = Path(p)
11
- if not p.exists():
12
- raise FileNotFoundError(f"IDs file not found: {p}")
13
- return [line.strip() for line in p.open() if line.strip()]
14
-
15
- def split_embeddings(emb_path, ids_path, out_dir, prefix):
16
- out_dir = Path(out_dir)
17
- out_dir.mkdir(parents=True, exist_ok=True)
18
-
19
- if not Path(emb_path).exists():
20
- raise FileNotFoundError(f"Embedding file not found: {emb_path}")
21
- if not Path(ids_path).exists():
22
- raise FileNotFoundError(f"IDs file not found: {ids_path}")
23
-
24
- if emb_path.endswith(".npz"):
25
- data = np.load(emb_path, allow_pickle=True)
26
- if "embeddings" in data:
27
- emb = data["embeddings"]
28
- else:
29
- raise ValueError(f"{emb_path} missing 'embeddings' key")
30
- else:
31
- emb = np.load(emb_path)
32
-
33
- ids = read_ids_file(ids_path)
34
- if len(ids) != emb.shape[0]:
35
- print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)
36
-
37
- mapping = {}
38
- for i, ident in enumerate(ids):
39
- if i >= emb.shape[0]:
40
- print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
41
- continue
42
- arr = emb[i]
43
- out_file = out_dir / f"{prefix}_{ident}.npy"
44
- np.save(out_file, arr)
45
- mapping[ident] = str(out_file)
46
- return mapping
47
-
48
- def extract_symbol_from_tf_id(full_id: str) -> str:
49
- """
50
- Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
51
- return the gene symbol uppercase (e.g., 'ZBTB5').
52
- """
53
- if "|" in full_id:
54
- try:
55
- # format sp|Accession|SYMBOL_HUMAN
56
- genepart = full_id.split("|")[2]
57
- except IndexError:
58
- genepart = full_id
59
- else:
60
- genepart = full_id
61
- symbol = genepart.split("_")[0]
62
- return symbol.upper()
63
-
64
- def build_tf_symbol_map(tf_map):
65
- """
66
- Build mapping gene_symbol -> list of embedding paths.
67
- """
68
- symbol_map = {}
69
- for full_id, path in tf_map.items():
70
- symbol = extract_symbol_from_tf_id(full_id)
71
- symbol_map.setdefault(symbol, []).append(path)
72
- return symbol_map
73
-
74
- def tf_key_from_path(path: str) -> str:
75
- """
76
- Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
77
- """
78
- stem = Path(path).stem # e.g., tf_sp|O15062|ZBTB5_HUMAN
79
- # remove leading prefix if present (tf_)
80
- if "_" in stem:
81
- _, rest = stem.split("_", 1)
82
- else:
83
- rest = stem
84
- return extract_symbol_from_tf_id(rest)
85
-
86
- def dna_key_from_path(path: str) -> str:
87
- """
88
- Given .../dna_peak42.npy -> 'peak42'
89
- """
90
- stem = Path(path).stem
91
- if "_" in stem:
92
- _, rest = stem.split("_", 1)
93
- else:
94
- rest = stem
95
- return rest
96
-
97
- def main():
98
- parser = argparse.ArgumentParser(
99
- description="Build TF-DNA pair list from final.csv with gene-symbol normalization for TFs."
100
- )
101
- parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
102
- parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
103
- parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (e.g., peak*.ids)")
104
- parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
105
- parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (e.g., sp|...|... ids)")
106
- parser.add_argument("--out_dir", required=True, help="Output directory")
107
- parser.add_argument("--neg_per_positive", type=int, default=2, help="Negatives per positive (half same-TF, half same-DNA)")
108
- parser.add_argument("--seed", type=int, default=42)
109
- args = parser.parse_args()
110
-
111
- random.seed(args.seed)
112
- out_dir = Path(args.out_dir)
113
- out_dir.mkdir(parents=True, exist_ok=True)
114
-
115
- # Load final.csv
116
- df = pd.read_csv(args.final_csv, dtype=str)
117
- if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
118
- raise RuntimeError("final.csv must have columns TF_id and dna_sequence")
119
-
120
- # Assign dna_id (unique per dna_sequence)
121
- unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
122
- seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
123
- df["dna_id"] = df["dna_sequence"].map(seq_to_id)
124
- enriched_csv = out_dir / "final_with_dna_id.csv"
125
- df.to_csv(enriched_csv, index=False)
126
- print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
127
-
128
- # Split embeddings into per-item files
129
- print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
130
- dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
131
- print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
132
- print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
133
- tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
134
- print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")
135
-
136
- # Build gene-symbol normalized map
137
- tf_symbol_map = build_tf_symbol_map(tf_map)
138
- print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")
139
-
140
- # Diagnostic overlaps
141
- norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
142
- available_tf_symbols = set(tf_symbol_map.keys())
143
- intersect_tf = norm_tf_in_final & available_tf_symbols
144
- print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
145
- print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
146
- print(f"[i] Intersection count: {len(intersect_tf)}")
147
- if len(intersect_tf) == 0:
148
- print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
149
- print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
150
- print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
151
- sys.exit(1)
152
-
153
- dna_ids_final = set(df["dna_id"].unique())
154
- available_dna_ids = set(dna_map.keys())
155
- intersect_dna = dna_ids_final & available_dna_ids
156
- print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
157
- if len(intersect_dna) == 0:
158
- print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
159
- sys.exit(1)
160
-
161
- # Build positive pairs
162
- positives = []
163
- for _, row in df.iterrows():
164
- tf_raw = row["TF_id"]
165
- tf_symbol = tf_raw.split("_seq")[0].upper()
166
- dnaid = row["dna_id"]
167
- if tf_symbol not in tf_symbol_map:
168
- continue
169
- if dnaid not in dna_map:
170
- continue
171
- # pick the first embedding for that symbol
172
- tf_embedding_path = tf_symbol_map[tf_symbol][0]
173
- positives.append((tf_embedding_path, dna_map[dnaid], 1))
174
- print(f"[i] Constructed {len(positives)} positive pairs after TF symbol resolution")
175
-
176
- if len(positives) == 0:
177
- print("[ERROR] No positive pairs could be constructed; aborting.", file=sys.stderr)
178
- sys.exit(1)
179
-
180
- # Build negative samples
181
- all_tf_symbols = sorted(tf_symbol_map.keys())
182
- all_dnaids = sorted(dna_map.keys())
183
- positive_set = set()
184
- for tf_path, dna_path, _ in positives:
185
- tf_key = tf_key_from_path(tf_path)
186
- dna_key = dna_key_from_path(dna_path)
187
- positive_set.add((tf_key, dna_key))
188
-
189
- negatives = []
190
- half = args.neg_per_positive // 2
191
- for tf_path, dna_path, _ in positives:
192
- tf_key = tf_key_from_path(tf_path)
193
- dna_key = dna_key_from_path(dna_path)
194
- # same TF, different DNA
195
- for _ in range(half):
196
- candidate_dna = random.choice(all_dnaids)
197
- if candidate_dna == dna_key or (tf_key, candidate_dna) in positive_set:
198
- continue
199
- negatives.append((tf_path, dna_map[candidate_dna], 0))
200
- # same DNA, different TF
201
- for _ in range(half):
202
- candidate_tf_symbol = random.choice(all_tf_symbols)
203
- if candidate_tf_symbol == tf_key or (candidate_tf_symbol, dna_key) in positive_set:
204
- continue
205
- # pick its first embedding
206
- candidate_tf_path = tf_symbol_map[candidate_tf_symbol][0]
207
- negatives.append((candidate_tf_path, dna_map[dnaid], 0))
208
-
209
- print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive={args.neg_per_positive})")
210
-
211
- # Write pair list
212
- pair_list_path = out_dir / "pair_list.tsv"
213
- with open(pair_list_path, "w") as f:
214
- for binder_path, glm_path, label in positives + negatives:
215
- # binder=TF, glm=DNA
216
- f.write(f"{binder_path}\t{glm_path}\t{label}\n")
217
- print(f"[i] Wrote {len(positives)+len(negatives)} examples to {pair_list_path}")
218
-
219
- if __name__ == "__main__":
220
- main()
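Each line of the pair_list.tsv written above is tab-separated: TF-embedding path, DNA-embedding path, 0/1 label. A minimal sketch of a loader for that format; the class name and any batching details are illustrative, not this repo's actual data module:

# Minimal sketch of consuming pair_list.tsv; class and field names are illustrative.
import numpy as np
import torch
from torch.utils.data import Dataset

class PairListDataset(Dataset):
    def __init__(self, tsv_path):
        self.rows = []
        with open(tsv_path) as f:
            for line in f:
                tf_path, dna_path, label = line.rstrip("\n").split("\t")
                self.rows.append((tf_path, dna_path, int(label)))

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        tf_path, dna_path, label = self.rows[idx]
        tf_emb = torch.from_numpy(np.load(tf_path)).float()
        dna_emb = torch.from_numpy(np.load(dna_path)).float()
        return tf_emb, dna_emb, torch.tensor(label, dtype=torch.float32)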
dpacman/classifier/model/make_peak_fasta.py DELETED
@@ -1,13 +0,0 @@
1
- import pandas as pd
2
- from pathlib import Path
3
-
4
- df = pd.read_csv("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv", dtype=str) # adjust path if needed
5
- # get unique sequences
6
- uniq = df[["dna_sequence"]].drop_duplicates().reset_index(drop=True)
7
- # make headers: e.g., peak0, peak1, ...
8
- out_fa = Path("binding_peaks_unique.fa")
9
- with open(out_fa, "w") as f:
10
- for i, seq in enumerate(uniq["dna_sequence"]):
11
- header = f">peak{i}"
12
- f.write(f"{header}\n{seq}\n")
13
- print(f"Wrote {len(uniq)} unique binding sequences to {out_fa}")
dpacman/classifier/model_tmp/clustering_data.py CHANGED
@@ -12,12 +12,14 @@ from collections import defaultdict
12
  # Original helpers (kept; some lightly edited/commented where needed)
13
  # ─────────────────────────────────────────────────────────────────────────
14
 
 
15
  def read_ids_file(p):
16
  p = Path(p)
17
  if not p.exists():
18
  raise FileNotFoundError(f"IDs file not found: {p}")
19
  return [line.strip() for line in p.open() if line.strip()]
20
 
 
21
  def split_embeddings(emb_path, ids_path, out_dir, prefix):
22
  out_dir = Path(out_dir)
23
  out_dir.mkdir(parents=True, exist_ok=True)
@@ -38,12 +40,17 @@ def split_embeddings(emb_path, ids_path, out_dir, prefix):
38
 
39
  ids = read_ids_file(ids_path)
40
  if len(ids) != emb.shape[0]:
41
- print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)
 
 
 
42
 
43
  mapping = {}
44
  for i, ident in enumerate(ids):
45
  if i >= emb.shape[0]:
46
- print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
 
 
47
  continue
48
  arr = emb[i]
49
  out_file = out_dir / f"{prefix}_{ident}.npy"
@@ -51,6 +58,7 @@ def split_embeddings(emb_path, ids_path, out_dir, prefix):
51
  mapping[ident] = str(out_file)
52
  return mapping
53
 
 
54
  def extract_symbol_from_tf_id(full_id: str) -> str:
55
  """
56
  Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
@@ -67,6 +75,7 @@ def extract_symbol_from_tf_id(full_id: str) -> str:
67
  symbol = genepart.split("_")[0]
68
  return symbol.upper()
69
 
 
70
  def build_tf_symbol_map(tf_map):
71
  """
72
  Build mapping gene_symbol -> list of embedding paths.
@@ -77,6 +86,7 @@ def build_tf_symbol_map(tf_map):
77
  symbol_map.setdefault(symbol, []).append(path)
78
  return symbol_map
79
 
 
80
  def tf_key_from_path(path: str) -> str:
81
  """
82
  Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
@@ -89,6 +99,7 @@ def tf_key_from_path(path: str) -> str:
89
  rest = stem
90
  return extract_symbol_from_tf_id(rest)
91
 
 
92
  def dna_key_from_path(path: str) -> str:
93
  """
94
  Given .../dna_peak42.npy -> 'peak42'
@@ -100,10 +111,12 @@ def dna_key_from_path(path: str) -> str:
100
  rest = stem
101
  return rest
102
 
 
103
  # ─────────────────────────────────────────────────────────────────────────
104
  # New helpers for MMseqs clustering & cluster-level splitting
105
  # ─────────────────────────────────────────────────────────────────────────
106
 
 
107
  def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
108
  """
109
  Write unique DNA sequences to FASTA using dna_id as header.
@@ -116,6 +129,7 @@ def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
116
  seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
117
  f.write(f">{did}\n{seq}\n")
118
 
 
119
  def run_mmseqs_easy_cluster(
120
  mmseqs_bin: str,
121
  fasta: Path,
@@ -133,11 +147,17 @@ def run_mmseqs_easy_cluster(
133
  out_prefix.parent.mkdir(parents=True, exist_ok=True)
134
 
135
  cmd = [
136
- mmseqs_bin, "easy-cluster",
137
- str(fasta), str(out_prefix), str(tmp_dir),
138
- "--min-seq-id", str(min_seq_id),
139
- "-c", str(coverage),
140
- "--cov-mode", str(cov_mode),
 
 
 
 
 
 
141
  # You can add performance flags here if needed, e.g.:
142
  # "--threads", "8"
143
  ]
@@ -157,14 +177,24 @@ def run_mmseqs_easy_cluster(
157
  cl_db = Path(str(out_prefix) + "_cluster")
158
  out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
159
  if in_db.exists() and cl_db.exists():
160
- cmd2 = [mmseqs_bin, "createtsv", str(in_db), str(in_db), str(cl_db), str(out_tsv)]
 
 
 
 
 
 
 
161
  print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
162
  subprocess.run(cmd2, check=True)
163
  if out_tsv.exists():
164
  return out_tsv
165
 
166
- raise FileNotFoundError("Could not locate clusters TSV from mmseqs. "
167
- "Expected {default_tsv} or createtsv fallback.")
 
 
 
168
 
169
  def parse_mmseqs_clusters(tsv_path: Path) -> dict:
170
  """
@@ -174,7 +204,7 @@ def parse_mmseqs_clusters(tsv_path: Path) -> dict:
174
  with open(tsv_path) as f:
175
  for line in f:
176
  parts = line.rstrip("\n").split("\t")
177
- if len(parts) < 2:
178
  continue
179
  rep, member = parts[0], parts[1]
180
  mapping[member] = rep
@@ -183,10 +213,10 @@ def parse_mmseqs_clusters(tsv_path: Path) -> dict:
183
  mapping[rep] = rep
184
  return mapping
185
 
186
- def assign_clusters_to_splits(cluster_rep_to_members: dict,
187
- val_frac: float,
188
- test_frac: float,
189
- seed: int = 42):
190
  """
191
  cluster_rep_to_members: dict[rep] = [members...]
192
  Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
@@ -208,38 +238,63 @@ def assign_clusters_to_splits(cluster_rep_to_members: dict,
208
  c = len(members)
209
  # Fill val first, then test, then train
210
  if cur_val + c <= target_val:
211
- val_ids.update(members); cur_val += c
 
212
  elif cur_test + c <= target_test:
213
- test_ids.update(members); cur_test += c
 
214
  else:
215
  train_ids.update(members)
216
 
217
  return {"train": train_ids, "val": val_ids, "test": test_ids}
218
 
 
219
  # ─────────────────────────────────────────────────────────────────────────
220
  # Main
221
  # ─────────────────────────────────────────────────────────────────────────
222
 
 
223
  def main():
224
  parser = argparse.ArgumentParser(
225
  description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
226
  )
227
- parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
228
- parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
229
- parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)")
230
- parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
231
- parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)")
 
 
 
 
 
 
 
 
 
 
232
  parser.add_argument("--out_dir", required=True, help="Output directory")
233
  parser.add_argument("--seed", type=int, default=42)
234
 
235
  # NEW: MMseqs options & split fractions
236
  parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
237
- parser.add_argument("--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id")
238
- parser.add_argument("--cov", type=float, default=0.8, help="MMseqs -c coverage fraction")
239
- parser.add_argument("--cov_mode", type=int, default=1, help="MMseqs --cov-mode (1 = coverage of target)")
 
 
 
 
 
 
 
 
 
240
  parser.add_argument("--val_frac", type=float, default=0.10)
241
  parser.add_argument("--test_frac", type=float, default=0.10)
242
- parser.add_argument("--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)")
 
 
243
  args = parser.parse_args()
244
 
245
  random.seed(args.seed)
@@ -260,12 +315,24 @@ def main():
260
  print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
261
 
262
  # Split embeddings into per-item files (unchanged)
263
- print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
264
- dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
265
- print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
266
- print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
267
- tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
268
- print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
  # Build gene-symbol normalized map
271
  tf_symbol_map = build_tf_symbol_map(tf_map)
@@ -279,15 +346,28 @@ def main():
279
  print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
280
  print(f"[i] Intersection count: {len(intersect_tf)}")
281
  if len(intersect_tf) == 0:
282
- print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
283
- print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
284
- print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
 
 
 
 
 
 
 
 
 
 
 
285
  sys.exit(1)
286
 
287
  dna_ids_final = set(df["dna_id"].unique())
288
  available_dna_ids = set(dna_map.keys())
289
  intersect_dna = dna_ids_final & available_dna_ids
290
- print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
 
 
291
  if len(intersect_dna) == 0:
292
  print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
293
  sys.exit(1)
@@ -295,7 +375,9 @@ def main():
295
  # ── NEW: MMseqs clustering on DNA sequences ───────────────────────────
296
  fasta_path = out_dir / "dna_unique.fasta"
297
  write_dna_fasta(df, fasta_path)
298
- print(f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}")
 
 
299
 
300
  tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
301
  cluster_prefix = out_dir / "mmseqs_dna_clusters"
@@ -310,7 +392,7 @@ def main():
310
  )
311
 
312
  # Parse clusters
313
- member_to_rep = parse_mmseqs_clusters(clusters_tsv) # dna_id -> rep_id
314
  # Build rep -> members list
315
  rep_to_members = defaultdict(list)
316
  for member, rep in member_to_rep.items():
@@ -331,10 +413,9 @@ def main():
331
  print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")
332
 
333
  # Assign entire clusters to splits
334
- splits = assign_clusters_to_splits(rep_to_members,
335
- val_frac=args.val_frac,
336
- test_frac=args.test_frac,
337
- seed=args.seed)
338
  for k in ["train", "val", "test"]:
339
  print(f"[i] {k}: {len(splits[k])} dna_ids")
340
 
@@ -354,14 +435,22 @@ def main():
354
 
355
  # decide split by dna_id cluster assignment
356
  if dnaid in splits["train"]:
357
- positives_by_split["train"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
 
 
358
  elif dnaid in splits["val"]:
359
- positives_by_split["val"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
 
 
360
  elif dnaid in splits["test"]:
361
- positives_by_split["test"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
 
 
362
  pos_count += 1
363
 
364
- print(f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})")
 
 
365
  for k in ["train", "val", "test"]:
366
  print(f"[i] positives[{k}] = {len(positives_by_split[k])}")
367
 
@@ -373,11 +462,14 @@ def main():
373
  for split in ["train", "val", "test"]:
374
  out_tsv = out_dir / f"pair_list_{split}.tsv"
375
  with open(out_tsv, "w") as f:
376
- for binder_path, glm_path, label in positives_by_split[split]: # + negatives if you add later
 
 
377
  f.write(f"{binder_path}\t{glm_path}\t{label}\n")
378
  print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")
379
 
380
  print("✅ Done. Cluster-aware splits ready.")
381
 
 
382
  if __name__ == "__main__":
383
  main()
 
12
  # Original helpers (kept; some lightly edited/commented where needed)
13
  # ─────────────────────────────────────────────────────────────────────────
14
 
15
+
16
  def read_ids_file(p):
17
  p = Path(p)
18
  if not p.exists():
19
  raise FileNotFoundError(f"IDs file not found: {p}")
20
  return [line.strip() for line in p.open() if line.strip()]
21
 
22
+
23
  def split_embeddings(emb_path, ids_path, out_dir, prefix):
24
  out_dir = Path(out_dir)
25
  out_dir.mkdir(parents=True, exist_ok=True)
 
40
 
41
  ids = read_ids_file(ids_path)
42
  if len(ids) != emb.shape[0]:
43
+ print(
44
+ f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}",
45
+ file=sys.stderr,
46
+ )
47
 
48
  mapping = {}
49
  for i, ident in enumerate(ids):
50
  if i >= emb.shape[0]:
51
+ print(
52
+ f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr
53
+ )
54
  continue
55
  arr = emb[i]
56
  out_file = out_dir / f"{prefix}_{ident}.npy"
 
58
  mapping[ident] = str(out_file)
59
  return mapping
60
 
61
+
62
  def extract_symbol_from_tf_id(full_id: str) -> str:
63
  """
64
  Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
 
75
  symbol = genepart.split("_")[0]
76
  return symbol.upper()
77
 
78
+
79
  def build_tf_symbol_map(tf_map):
80
  """
81
  Build mapping gene_symbol -> list of embedding paths.
 
86
  symbol_map.setdefault(symbol, []).append(path)
87
  return symbol_map
88
 
89
+
90
  def tf_key_from_path(path: str) -> str:
91
  """
92
  Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
 
99
  rest = stem
100
  return extract_symbol_from_tf_id(rest)
101
 
102
+
103
  def dna_key_from_path(path: str) -> str:
104
  """
105
  Given .../dna_peak42.npy -> 'peak42'
 
111
  rest = stem
112
  return rest
113
 
114
+
115
  # ─────────────────────────────────────────────────────────────────────────
116
  # New helpers for MMseqs clustering & cluster-level splitting
117
  # ─────────────────────────────────────────────────────────────────────────
118
 
119
+
120
  def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
121
  """
122
  Write unique DNA sequences to FASTA using dna_id as header.
 
129
  seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
130
  f.write(f">{did}\n{seq}\n")
131
 
132
+
133
  def run_mmseqs_easy_cluster(
134
  mmseqs_bin: str,
135
  fasta: Path,
 
147
  out_prefix.parent.mkdir(parents=True, exist_ok=True)
148
 
149
  cmd = [
150
+ mmseqs_bin,
151
+ "easy-cluster",
152
+ str(fasta),
153
+ str(out_prefix),
154
+ str(tmp_dir),
155
+ "--min-seq-id",
156
+ str(min_seq_id),
157
+ "-c",
158
+ str(coverage),
159
+ "--cov-mode",
160
+ str(cov_mode),
161
  # You can add performance flags here if needed, e.g.:
162
  # "--threads", "8"
163
  ]
 
177
  cl_db = Path(str(out_prefix) + "_cluster")
178
  out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
179
  if in_db.exists() and cl_db.exists():
180
+ cmd2 = [
181
+ mmseqs_bin,
182
+ "createtsv",
183
+ str(in_db),
184
+ str(in_db),
185
+ str(cl_db),
186
+ str(out_tsv),
187
+ ]
188
  print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
189
  subprocess.run(cmd2, check=True)
190
  if out_tsv.exists():
191
  return out_tsv
192
 
193
+ raise FileNotFoundError(
194
+ "Could not locate clusters TSV from mmseqs. "
195
+ "Expected {default_tsv} or createtsv fallback."
196
+ )
197
+
198
 
199
  def parse_mmseqs_clusters(tsv_path: Path) -> dict:
200
  """
 
204
  with open(tsv_path) as f:
205
  for line in f:
206
  parts = line.rstrip("\n").split("\t")
207
+ if len(parts) < 2:
208
  continue
209
  rep, member = parts[0], parts[1]
210
  mapping[member] = rep
 
213
  mapping[rep] = rep
214
  return mapping
215
 
216
+
217
+ def assign_clusters_to_splits(
218
+ cluster_rep_to_members: dict, val_frac: float, test_frac: float, seed: int = 42
219
+ ):
220
  """
221
  cluster_rep_to_members: dict[rep] = [members...]
222
  Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
 
238
  c = len(members)
239
  # Fill val first, then test, then train
240
  if cur_val + c <= target_val:
241
+ val_ids.update(members)
242
+ cur_val += c
243
  elif cur_test + c <= target_test:
244
+ test_ids.update(members)
245
+ cur_test += c
246
  else:
247
  train_ids.update(members)
248
 
249
  return {"train": train_ids, "val": val_ids, "test": test_ids}
250
 
251
+
252
  # ─────────────────────────────────────────────────────────────────────────
253
  # Main
254
  # ─────────────────────────────────────────────────────────────────────────
255
 
256
+
257
  def main():
258
  parser = argparse.ArgumentParser(
259
  description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
260
  )
261
+ parser.add_argument(
262
+ "--final_csv", required=True, help="final.csv with TF_id and dna_sequence"
263
+ )
264
+ parser.add_argument(
265
+ "--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)"
266
+ )
267
+ parser.add_argument(
268
+ "--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)"
269
+ )
270
+ parser.add_argument(
271
+ "--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)"
272
+ )
273
+ parser.add_argument(
274
+ "--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)"
275
+ )
276
  parser.add_argument("--out_dir", required=True, help="Output directory")
277
  parser.add_argument("--seed", type=int, default=42)
278
 
279
  # NEW: MMseqs options & split fractions
280
  parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
281
+ parser.add_argument(
282
+ "--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id"
283
+ )
284
+ parser.add_argument(
285
+ "--cov", type=float, default=0.8, help="MMseqs -c coverage fraction"
286
+ )
287
+ parser.add_argument(
288
+ "--cov_mode",
289
+ type=int,
290
+ default=1,
291
+ help="MMseqs --cov-mode (1 = coverage of target)",
292
+ )
293
  parser.add_argument("--val_frac", type=float, default=0.10)
294
  parser.add_argument("--test_frac", type=float, default=0.10)
295
+ parser.add_argument(
296
+ "--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)"
297
+ )
298
  args = parser.parse_args()
299
 
300
  random.seed(args.seed)
 
315
  print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
316
 
317
  # Split embeddings into per-item files (unchanged)
318
+ print(
319
+ f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}"
320
+ )
321
+ dna_map = split_embeddings(
322
+ args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna"
323
+ )
324
+ print(
325
+ f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})"
326
+ )
327
+ print(
328
+ f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}"
329
+ )
330
+ tf_map = split_embeddings(
331
+ args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf"
332
+ )
333
+ print(
334
+ f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})"
335
+ )
336
 
337
  # Build gene-symbol normalized map
338
  tf_symbol_map = build_tf_symbol_map(tf_map)
 
346
  print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
347
  print(f"[i] Intersection count: {len(intersect_tf)}")
348
  if len(intersect_tf) == 0:
349
+ print(
350
+ "[ERROR] No overlap between normalized TF_id and TF embedding symbols.",
351
+ file=sys.stderr,
352
+ )
353
+ print(
354
+ "Sample normalized TFs from final.csv:",
355
+ sorted(list(norm_tf_in_final))[:30],
356
+ file=sys.stderr,
357
+ )
358
+ print(
359
+ "Sample available TF symbols:",
360
+ sorted(list(available_tf_symbols))[:30],
361
+ file=sys.stderr,
362
+ )
363
  sys.exit(1)
364
 
365
  dna_ids_final = set(df["dna_id"].unique())
366
  available_dna_ids = set(dna_map.keys())
367
  intersect_dna = dna_ids_final & available_dna_ids
368
+ print(
369
+ f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}"
370
+ )
371
  if len(intersect_dna) == 0:
372
  print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
373
  sys.exit(1)
 
375
  # ── NEW: MMseqs clustering on DNA sequences ───────────────────────────
376
  fasta_path = out_dir / "dna_unique.fasta"
377
  write_dna_fasta(df, fasta_path)
378
+ print(
379
+ f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}"
380
+ )
381
 
382
  tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
383
  cluster_prefix = out_dir / "mmseqs_dna_clusters"
 
392
  )
393
 
394
  # Parse clusters
395
+ member_to_rep = parse_mmseqs_clusters(clusters_tsv) # dna_id -> rep_id
396
  # Build rep -> members list
397
  rep_to_members = defaultdict(list)
398
  for member, rep in member_to_rep.items():
 
413
  print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")
414
 
415
  # Assign entire clusters to splits
416
+ splits = assign_clusters_to_splits(
417
+ rep_to_members, val_frac=args.val_frac, test_frac=args.test_frac, seed=args.seed
418
+ )
 
419
  for k in ["train", "val", "test"]:
420
  print(f"[i] {k}: {len(splits[k])} dna_ids")
421
 
 
435
 
436
  # decide split by dna_id cluster assignment
437
  if dnaid in splits["train"]:
438
+ positives_by_split["train"].append(
439
+ (tf_embedding_path, dnaid_to_path[dnaid], 1)
440
+ )
441
  elif dnaid in splits["val"]:
442
+ positives_by_split["val"].append(
443
+ (tf_embedding_path, dnaid_to_path[dnaid], 1)
444
+ )
445
  elif dnaid in splits["test"]:
446
+ positives_by_split["test"].append(
447
+ (tf_embedding_path, dnaid_to_path[dnaid], 1)
448
+ )
449
  pos_count += 1
450
 
451
+ print(
452
+ f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})"
453
+ )
454
  for k in ["train", "val", "test"]:
455
  print(f"[i] positives[{k}] = {len(positives_by_split[k])}")
456
 
 
462
  for split in ["train", "val", "test"]:
463
  out_tsv = out_dir / f"pair_list_{split}.tsv"
464
  with open(out_tsv, "w") as f:
465
+ for binder_path, glm_path, label in positives_by_split[
466
+ split
467
+ ]: # + negatives if you add later
468
  f.write(f"{binder_path}\t{glm_path}\t{label}\n")
469
  print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")
470
 
471
  print("✅ Done. Cluster-aware splits ready.")
472
 
473
+
474
  if __name__ == "__main__":
475
  main()
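mmseqs easy-cluster reports clusters as a two-column representative / member TSV (tab-separated), which parse_mmseqs_clusters() inverts into a member → representative map; whole clusters are then assigned to a single split so near-identical peaks cannot leak across train/val/test. A toy illustration of that flow with invented peak ids:

# Toy illustration; the cluster TSV content here is invented.
from collections import defaultdict

toy_tsv = "peak0\tpeak0\npeak0\tpeak7\npeak3\tpeak3\n"   # rep<tab>member lines
member_to_rep = {}
for line in toy_tsv.splitlines():
    rep, member = line.split("\t")
    member_to_rep[member] = rep

rep_to_members = defaultdict(list)
for member, rep in member_to_rep.items():
    rep_to_members[rep].append(member)

# peak0 and peak7 share a cluster, so they always land in the same split
print(dict(rep_to_members))   # {'peak0': ['peak0', 'peak7'], 'peak3': ['peak3']}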
dpacman/classifier/model_tmp/compress_embeddings.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
7
  import torch
8
  from torch import nn
9
 
 
10
  class EmbeddingCompressor(nn.Module):
11
  def __init__(self, input_dim: int = 1280, output_dim: int = 256):
12
  super().__init__()
@@ -19,26 +20,33 @@ class EmbeddingCompressor(nn.Module):
19
  """
20
  if x.dim() == 2:
21
  # single example: mean over tokens
22
- x = x.mean(dim=0, keepdim=True) # → (1, input_dim)
23
  else:
24
  # batch: mean over tokens
25
- x = x.mean(dim=1) # → (batch, input_dim)
26
- return self.fc(x) # → (batch, output_dim)
 
27
 
28
  def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
29
- arr = np.load(in_path) # shape (L, D) or (batch, L, D)
30
  tensor = torch.from_numpy(arr).float()
31
  with torch.no_grad():
32
- compressed = model(tensor) # → (batch, 256)
33
  out = compressed.cpu().numpy()
34
  np.save(out_path, out)
35
  print(f"Saved {out_path}")
36
 
 
37
  if __name__ == "__main__":
38
  import argparse
 
39
  parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256­d")
40
- parser.add_argument("--input_glob", type=str, required=True,
41
- help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)")
 
 
 
 
42
  parser.add_argument("--output_dir", type=str, required=True)
43
  parser.add_argument("--esm_dim", type=int, default=1280)
44
  parser.add_argument("--out_dim", type=int, default=256)
 
7
  import torch
8
  from torch import nn
9
 
10
+
11
  class EmbeddingCompressor(nn.Module):
12
  def __init__(self, input_dim: int = 1280, output_dim: int = 256):
13
  super().__init__()
 
20
  """
21
  if x.dim() == 2:
22
  # single example: mean over tokens
23
+ x = x.mean(dim=0, keepdim=True) # → (1, input_dim)
24
  else:
25
  # batch: mean over tokens
26
+ x = x.mean(dim=1) # → (batch, input_dim)
27
+ return self.fc(x) # → (batch, output_dim)
28
+
29
 
30
  def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
31
+ arr = np.load(in_path) # shape (L, D) or (batch, L, D)
32
  tensor = torch.from_numpy(arr).float()
33
  with torch.no_grad():
34
+ compressed = model(tensor) # → (batch, 256)
35
  out = compressed.cpu().numpy()
36
  np.save(out_path, out)
37
  print(f"Saved {out_path}")
38
 
39
+
40
  if __name__ == "__main__":
41
  import argparse
42
+
43
  parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256­d")
44
+ parser.add_argument(
45
+ "--input_glob",
46
+ type=str,
47
+ required=True,
48
+ help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)",
49
+ )
50
  parser.add_argument("--output_dir", type=str, required=True)
51
  parser.add_argument("--esm_dim", type=int, default=1280)
52
  parser.add_argument("--out_dim", type=int, default=256)
dpacman/classifier/model_tmp/compute_embeddings.py CHANGED
@@ -14,6 +14,7 @@ Usage example (DNA + protein in one go):
   --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
   --device cuda
 """
+
 import os
 import re
 import argparse
@@ -28,6 +29,7 @@ import time
 
 # ---- model wrappers ----
 
+
 class CaduceusEmbedder:
     def __init__(self, device, chunk_size=131_072, overlap=0):
         """
@@ -39,12 +41,14 @@ class CaduceusEmbedder:
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name, trust_remote_code=True
         )
-        self.model = AutoModel.from_pretrained(
-            model_name, trust_remote_code=True
-        ).to(device).eval()
-        self.device = device
+        self.model = (
+            AutoModel.from_pretrained(model_name, trust_remote_code=True)
+            .to(device)
+            .eval()
+        )
+        self.device = device
         self.chunk_size = chunk_size
-        self.step = chunk_size - overlap
+        self.step = chunk_size - overlap
 
     def embed(self, seqs):
         """
@@ -73,14 +77,13 @@ class CaduceusEmbedder:
                 return_tensors="pt",
                 padding=False,
                 truncation=True,
-                max_length=self.chunk_size
+                max_length=self.chunk_size,
             ).to(self.device)
             with torch.no_grad():
                 out = self.model(**toks).last_hidden_state  # (1, L, D)
-                outputs.append(out.cpu().numpy()[0])  # (L, D)
+                outputs.append(out.cpu().numpy()[0])  # (L, D)
         return outputs  # list of variable-length (L_i, D) arrays
 
-
     def benchmark(self, lengths=None):
         """
         Time embedding on single-sequence of various lengths.
@@ -101,10 +104,17 @@ class CaduceusEmbedder:
             t1 = time.perf_counter()
             print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms")
 
+
 class SegmentNTEmbedder:
     def __init__(self, device):
-        self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
-        self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval()
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "InstaDeepAI/segment_nt", trust_remote_code=True
+        )
+        self.model = (
+            AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
+            .to(device)
+            .eval()
+        )
         self.device = device
 
     def _adjust_length(self, input_ids):
@@ -113,7 +123,12 @@ class SegmentNTEmbedder:
         remainder = (excl) % 4
         if remainder != 0:
             pad_needed = 4 - remainder
-            pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
+            pad_tensor = torch.full(
+                (bs, pad_needed),
+                self.tokenizer.pad_token_id,
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
             input_ids = torch.cat([input_ids, pad_tensor], dim=1)
         return input_ids
 
@@ -135,7 +150,7 @@ class SegmentNTEmbedder:
         attention_mask = input_ids != self.tokenizer.pad_token_id
 
         input_ids = self._adjust_length(input_ids)
-        attention_mask = (input_ids != self.tokenizer.pad_token_id)
+        attention_mask = input_ids != self.tokenizer.pad_token_id
 
         with torch.no_grad():
             outs = self.model(
@@ -161,19 +176,26 @@ class SegmentNTEmbedder:
 
 class DNABertEmbedder:
     def __init__(self, device):
-        self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
-        self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device)
-        self.device = device
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "zhihan1996/DNA_bert_6", trust_remote_code=True
+        )
+        self.model = AutoModel.from_pretrained(
+            "zhihan1996/DNA_bert_6", trust_remote_code=True
+        ).to(device)
+        self.device = device
 
     def embed(self, seqs):
         embs = []
         for s in seqs:
-            tokens = self.tokenizer(s, return_tensors="pt", padding=True)["input_ids"].to(self.device)
+            tokens = self.tokenizer(s, return_tensors="pt", padding=True)[
+                "input_ids"
+            ].to(self.device)
             with torch.no_grad():
                 out = self.model(tokens).last_hidden_state.mean(1)
             embs.append(out.cpu().numpy())
         return np.vstack(embs)
 
+
 class NucleotideTransformerEmbedder:
     def __init__(self, device):
         # HF “feature-extraction” returns a list of (L, D) arrays for each input
@@ -181,7 +203,9 @@ class NucleotideTransformerEmbedder:
         self.pipe = pipeline(
             "feature-extraction",
             model="InstaDeepAI/nucleotide-transformer-500m-1000g",
-            device= -1 if device=="cpu" else 0  # HF uses -1 for CPU, 0 for GPU #:contentReference[oaicite:0]{index=0}
+            device=(
+                -1 if device == "cpu" else 0
+            ),  # HF uses -1 for CPU, 0 for GPU #:contentReference[oaicite:0]{index=0}
         )
 
     def embed(self, seqs):
@@ -191,8 +215,9 @@ class NucleotideTransformerEmbedder:
         """
         all_embeddings = self.pipe(seqs, truncation=True, padding=True)
         # all_embeddings is a List of shape (L, D) arrays
-        pooled = [ np.mean(x, axis=0) for x in all_embeddings ]
-        return np.vstack(pooled)
+        pooled = [np.mean(x, axis=0) for x in all_embeddings]
+        return np.vstack(pooled)
+
 
 # class ESMEmbedder:
 #     def __init__(self, device):
@@ -225,7 +250,9 @@ class ESMEmbedder:
         self.batch_converter = self.alphabet.get_batch_converter()
         self.model.to(device).eval()
         # determine max length: esm2 models vary; use default 1024 for esm1b
-        self.max_len = 4096 if self.is_esm2 else 1024  # adjust if your esm2 variant has explicit limit
+        self.max_len = (
+            4096 if self.is_esm2 else 1024
+        )  # adjust if your esm2 variant has explicit limit
         # for chunking: reserve 2 tokens if model uses BOS/EOS
         self.chunk_size = self.max_len - 2
         self.overlap = self.chunk_size // 4  # 25% overlap to smooth boundaries
@@ -280,7 +307,7 @@ class ESMEmbedder:
 
 # class ESMDBPEmbedder:
 #     def __init__(self, device):
-#         base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
+#         base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
 #         model_path = (
 #             Path(__file__).resolve().parent.parent
 #             / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
@@ -310,12 +337,15 @@ class ESMEmbedder:
 #         # skip start/end tokens
 #         return reps[:, 1:-1].mean(1).cpu().numpy()
 
+
 class ESMDBPEmbedder:
     def __init__(self, device):
         base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
         model_path = (
             Path(__file__).resolve().parent.parent
-            / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
+            / "pretrained"
+            / "ESM-DBP"
+            / "ESM-DBP.model"
         )
         checkpoint = torch.load(model_path, map_location="cpu")
         clean_sd = {}
@@ -372,6 +402,7 @@ class ESMDBPEmbedder:
             all_embeddings.append(seq_vec.cpu().numpy())
         return np.vstack(all_embeddings)
 
+
 class GPNEmbedder:
     def __init__(self, device):
         model_name = "songlab/gpn-msa-sapiens"
@@ -383,16 +414,14 @@ class GPNEmbedder:
 
     def embed(self, seqs):
         inputs = self.tokenizer(
-            seqs,
-            return_tensors="pt",
-            padding=True,
-            truncation=True
+            seqs, return_tensors="pt", padding=True, truncation=True
         ).to(self.device)
 
         with torch.no_grad():
             last_hidden = self.model(**inputs).last_hidden_state
         return last_hidden.mean(dim=1).cpu().numpy()
 
+
 class ProGenEmbedder:
     def __init__(self, device):
         model_name = "jinyuan22/ProGen2-base"
@@ -402,29 +431,36 @@ class ProGenEmbedder:
 
     def embed(self, seqs):
         inputs = self.tokenizer(
-            seqs,
-            return_tensors="pt",
-            padding=True,
-            truncation=True
+            seqs, return_tensors="pt", padding=True, truncation=True
         ).to(self.device)
         with torch.no_grad():
            last_hidden = self.model(**inputs).last_hidden_state
         return last_hidden.mean(dim=1).cpu().numpy()
 
+
 # ---- main pipeline ----
 
+
 def get_embedder(name, device, for_dna=True):
     name = name.lower()
     if for_dna:
-        if name=="caduceus": return CaduceusEmbedder(device)
-        if name=="dnabert": return DNABertEmbedder(device)
-        if name=="nucleotide": return NucleotideTransformerEmbedder(device)
-        if name=="gpn": return GPNEmbedder(device)
-        if name=="segmentnt": return SegmentNTEmbedder(device)
+        if name == "caduceus":
+            return CaduceusEmbedder(device)
+        if name == "dnabert":
+            return DNABertEmbedder(device)
+        if name == "nucleotide":
+            return NucleotideTransformerEmbedder(device)
+        if name == "gpn":
+            return GPNEmbedder(device)
+        if name == "segmentnt":
+            return SegmentNTEmbedder(device)
     else:
-        if name in ("esm",): return ESMEmbedder(device)
-        if name in ("esm-dbp","esm_dbp"): return ESMDBPEmbedder(device)
-        if name=="progen": return ProGenEmbedder(device)
+        if name in ("esm",):
+            return ESMEmbedder(device)
+        if name in ("esm-dbp", "esm_dbp"):
+            return ESMDBPEmbedder(device)
+        if name == "progen":
+            return ProGenEmbedder(device)
     raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
 
 
@@ -446,20 +482,28 @@ def pad_token_embeddings(list_of_arrays, pad_value=0.0):
         mask[i, :L] = True
     return padded, mask
 
+
 def embed_and_save(seqs, ids, embedder, out_path):
     embs = embedder.embed(seqs)
 
     # Decide whether we got variable-length per-token outputs (list of (L, D))
-    is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and embs[0].ndim == 2
+    is_variable_token = (
+        isinstance(embs, (list, tuple))
+        and len(embs) > 0
+        and hasattr(embs[0], "shape")
+        and embs[0].ndim == 2
+    )
 
     if is_variable_token:
         # pad to (N, L_max, D) + mask
         padded, mask = pad_token_embeddings(embs)
         # Save both embeddings and mask together in an .npz for convenience
-        np.savez_compressed(out_path.with_suffix(".caduceus.npz"),
-                            embeddings=padded,
-                            mask=mask,
-                            ids=np.array(ids, dtype=object))
+        np.savez_compressed(
+            out_path.with_suffix(".caduceus.npz"),
+            embeddings=padded,
+            mask=mask,
+            ids=np.array(ids, dtype=object),
+        )
     else:
         # fixed shape output, e.g., pooled (N, D)
         array = np.vstack(embs) if isinstance(embs, list) else embs
@@ -468,17 +512,31 @@ def embed_and_save(seqs, ids, embedder, out_path):
         f.write("\n".join(ids))
 
 
-if __name__=="__main__":
+if __name__ == "__main__":
 
     p = argparse.ArgumentParser()
-    p.add_argument("--peak-fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
-    p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
-    p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step")  #if glm embeddings successful but not plm embeddings
-    p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
-    p.add_argument("--chrom-model", default="caduceus")
-    p.add_argument("--tf-model", default="esm-dbp")
-    p.add_argument("--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings")
-    p.add_argument("--device", default="cpu")
+    p.add_argument(
+        "--peak-fasta",
+        default="binding_peaks_unique.fa",
+        help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs",
+    )
+    p.add_argument(
+        "--genome-json-dir",
+        default=None,
+        help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes",
+    )
+    p.add_argument(
+        "--skip-dna",
+        action="store_true",
+        help="if set, skip the chromosome embedding step",
+    )  # if glm embeddings successful but not plm embeddings
+    p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
+    p.add_argument("--chrom-model", default="caduceus")
+    p.add_argument("--tf-model", default="esm-dbp")
+    p.add_argument(
+        "--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings"
+    )
+    p.add_argument("--device", default="cpu")
     args = p.parse_args()
 
     os.makedirs(args.out_dir, exist_ok=True)
@@ -495,7 +553,10 @@ if __name__=="__main__":
         for rec in SeqIO.parse(peak_fasta, "fasta"):
            peak_ids.append(rec.id)
            peak_seqs.append(str(rec.seq))
-        print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
+        print(
+            f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}",
+            flush=True,
+        )
         dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
         out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
         embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
@@ -503,7 +564,9 @@ if __name__=="__main__":
         # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
         genome_dir = Path(args.genome_json_dir)
         chrom_seqs, chrom_ids = [], []
-        primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$")
+        primary_pattern = re.compile(
+            r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$"
+        )
         for j in sorted(genome_dir.iterdir()):
             if not primary_pattern.match(j.name):
                 continue
@@ -519,7 +582,9 @@ if __name__=="__main__":
             if len(seq) > cutoff
         ]
         if long_chroms:
-            print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff))
+            print(
+                "⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff)
+            )
             for chrom, L in long_chroms:
                 print(f" {chrom}: {L} bases")
         else:
@@ -529,10 +594,11 @@ if __name__=="__main__":
         out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
         embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
     else:
-        raise ValueError("No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs.")
-
+        raise ValueError(
+            "No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs."
+        )
 
-    #Load TF sequences
+    # Load TF sequences
     tf_seqs, tf_ids = [], []
    for record in SeqIO.parse(args.tf_fasta, "fasta"):
         tf_ids.append(record.id)
@@ -543,4 +609,4 @@ if __name__=="__main__":
     out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
     embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
 
-    print("Done.")
+    print("Done.")