training works
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitignore +6 -1
- configs/callbacks/default.yaml +21 -0
- configs/callbacks/early_stopping.yaml +15 -0
- configs/callbacks/model_checkpoint.yaml +17 -0
- configs/callbacks/model_summary.yaml +5 -0
- dpacman/classifier/model/__init__.py → configs/callbacks/none.yaml +0 -0
- configs/callbacks/rich_progress_bar.yaml +4 -0
- configs/data_module/pair.yaml +13 -0
- configs/data_modules/pair.yaml +0 -9
- configs/data_task/cluster/remap.yaml +1 -1
- configs/data_task/download/genome.yaml +1 -1
- configs/data_task/download/remap.yaml +1 -1
- configs/data_task/embeddings/dna.yaml +8 -4
- configs/data_task/embeddings/protein.yaml +14 -0
- configs/data_task/fimo/post_fimo.yaml +1 -1
- configs/data_task/fimo/pre_fimo.yaml +1 -1
- configs/data_task/fimo/run_fimo.yaml +1 -1
- configs/data_task/split/remap.yaml +3 -0
- configs/extras/default.yaml +8 -0
- configs/logger/aim.yaml +28 -0
- configs/logger/comet.yaml +12 -0
- configs/logger/csv.yaml +7 -0
- configs/logger/many_loggers.yaml +9 -0
- configs/logger/mlflow.yaml +12 -0
- configs/logger/neptune.yaml +9 -0
- configs/logger/tensorboard.yaml +10 -0
- configs/logger/wandb.yaml +16 -0
- configs/model/classifier.yaml +9 -0
- configs/{models → model}/pooling/truncatedsvd.yaml +0 -0
- configs/models/classifier.yaml +0 -11
- configs/preprocess.yaml +1 -1
- configs/train.yaml +37 -2
- configs/trainer/cpu.yaml +5 -0
- configs/trainer/ddp.yaml +9 -0
- configs/trainer/ddp_sim.yaml +7 -0
- configs/trainer/default.yaml +19 -0
- configs/trainer/gpu.yaml +5 -0
- configs/trainer/mps.yaml +5 -0
- dpacman/classifier/loss.py +58 -0
- dpacman/classifier/model.py +258 -0
- dpacman/classifier/model/clustering_data.py +0 -383
- dpacman/classifier/model/compress_embeddings.py +0 -54
- dpacman/classifier/model/compute_embeddings.py +0 -560
- dpacman/classifier/model/extract_tf_symbols.py +0 -27
- dpacman/classifier/model/loss.py +0 -34
- dpacman/classifier/model/make_pair_list.py +0 -220
- dpacman/classifier/model/make_peak_fasta.py +0 -13
- dpacman/classifier/model_tmp/clustering_data.py +139 -47
- dpacman/classifier/model_tmp/compress_embeddings.py +15 -7
- dpacman/classifier/model_tmp/compute_embeddings.py +125 -59
.gitignore
CHANGED
@@ -29,4 +29,9 @@ dpacman/nohup.out
 dpacman/*/__pycache__/
 dpacman/data_tasks/split/__pycache__/
 dpacman/data_tasks/cluster/__pycache__/
-dpacman/data_tasks/embeddings/__pycache__/
+dpacman/data_tasks/embeddings/__pycache__/
+dpacman/combine_shards.py
+dpacman/combine.log
+dpacman/loss_sim.py
+dpacman/loss_temp.py
+dpacman/peak_examples/

configs/callbacks/default.yaml
ADDED
@@ -0,0 +1,21 @@
+defaults:
+  - model_checkpoint
+  - early_stopping
+  - model_summary
+  - _self_
+
+model_checkpoint:
+  dirpath: ${paths.output_dir}/checkpoints
+  filename: "epoch_{epoch:03d}"
+  monitor: "val/loss"
+  mode: "min"
+  save_last: True
+  auto_insert_metric_name: False
+
+early_stopping:
+  monitor: "val/loss"
+  patience: 100
+  mode: "min"
+
+model_summary:
+  max_depth: -1

configs/callbacks/early_stopping.yaml
ADDED
@@ -0,0 +1,15 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
+
+early_stopping:
+  _target_: lightning.pytorch.callbacks.EarlyStopping
+  monitor: ??? # quantity to be monitored, must be specified !!!
+  min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
+  patience: 3 # number of checks with no improvement after which training will be stopped
+  verbose: False # verbosity mode
+  mode: "min" # "max" means higher metric value is better, can be also "min"
+  strict: True # whether to crash the training if monitor is not found in the validation metrics
+  check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
+  stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
+  divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
+  check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
+  # log_rank_zero_only: False # this keyword argument isn't available in stable version

configs/callbacks/model_checkpoint.yaml
ADDED
@@ -0,0 +1,17 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
+
+model_checkpoint:
+  _target_: lightning.pytorch.callbacks.ModelCheckpoint
+  dirpath: null # directory to save the model file
+  filename: null # checkpoint filename
+  monitor: null # name of the logged metric which determines when model is improving
+  verbose: False # verbosity mode
+  save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+  save_top_k: 1 # save k best models (determined by above metric)
+  mode: "min" # "max" means higher metric value is better, can be also "min"
+  auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
+  save_weights_only: False # if True, then only the model's weights will be saved
+  every_n_train_steps: null # number of training steps between checkpoints
+  train_time_interval: null # checkpoints are monitored at the specified time interval
+  every_n_epochs: null # number of epochs between checkpoints
+  save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation

configs/callbacks/model_summary.yaml
ADDED
@@ -0,0 +1,5 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
+
+model_summary:
+  _target_: lightning.pytorch.callbacks.RichModelSummary
+  max_depth: 1 # the maximum depth of layer nesting that the summary will include

dpacman/classifier/model/__init__.py → configs/callbacks/none.yaml
RENAMED
File without changes

configs/callbacks/rich_progress_bar.yaml
ADDED
@@ -0,0 +1,4 @@
+# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
+
+rich_progress_bar:
+  _target_: lightning.pytorch.callbacks.RichProgressBar

configs/data_module/pair.yaml
ADDED
@@ -0,0 +1,13 @@
+_target_: dpacman.data_modules.pair.PairDataModule
+
+train_file: data_files/processed/splits/by_dna/babytrain.csv
+val_file: data_files/processed/splits/by_dna/babyval.csv
+test_file: data_files/processed/splits/by_dna/babytest.csv
+
+tr_shelf_path: data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf
+dna_shelf_path: data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf
+
+batch_size: 32
+num_workers: 8
+
+maximize_num_workers: False

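The new pair data module reads the split CSVs and looks up precomputed protein (ESM) and DNA (SegmentNT) embeddings from the two shelf files. A minimal sketch of inspecting one of those stores, assuming they are standard Python shelve databases keyed by sequence ID (the key and array shape below are illustrative, not taken from the repo):

    import shelve

    # Assumption: the *.shelf files are plain `shelve` stores mapping ID -> numpy array.
    with shelve.open("data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf", flag="r") as db:
        first_id = next(iter(db.keys()))   # e.g. a transcription-factor sequence ID
        emb = db[first_id]                 # per-residue embedding, e.g. shape (L, 1280) for ESM
        print(first_id, getattr(emb, "shape", None))
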
configs/data_modules/pair.yaml
DELETED
@@ -1,9 +0,0 @@
-
-train_file: data_files/splits/train.csv
-val_file: data_files/splits/val.csv
-test_file: data_files/splits/test.csv
-
-batch_size: 32
-num_workers: 8
-
-maximize_num_workers: False

configs/data_task/cluster/remap.yaml
CHANGED
@@ -1,5 +1,5 @@
 name: remap
-
+task_type: cluster
 
 max_protein_length: 1998
 

configs/data_task/download/genome.yaml
CHANGED
@@ -1,5 +1,5 @@
 name: genome
-
+task_type: download
 output_dir: dpacman/data_files/raw/genomes
 genomes:
   - hg38

configs/data_task/download/remap.yaml
CHANGED
@@ -1,5 +1,5 @@
 name: remap
-
+task_type: download
 
 nr_url: https://remap.univ-amu.fr/storage/remap2022/hg38/MACS2/remap2022_nr_macs2_hg38_v1_0.bed.gz
 nr_output_dir: dpacman/data_files/raw/remap

configs/data_task/embeddings/dna.yaml
CHANGED
@@ -1,9 +1,13 @@
 name: dna
-
+task_type: embeddings
 
 genome_json_dir: null
-chrom_model:
-input_file: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/
+chrom_model: segmentnt
+input_file: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence_with_rc.json
 out_dir: dpacman/data_files/processed/embeddings/fimo_hits_only
 
-device: gpu
+device: gpu
+
+batch_size: 1
+
+debug: false

configs/data_task/embeddings/protein.yaml
CHANGED
@@ -0,0 +1,14 @@
+name: protein
+task_type: embeddings
+
+prot_model: esm
+input_file: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/tr_seqid_to_tr_sequence.json
+out_dir: dpacman/data_files/processed/embeddings/fimo_hits_only
+
+device: gpu
+
+save_as_shelf: true
+
+batch_size: 1
+
+debug: false

configs/data_task/fimo/post_fimo.yaml
CHANGED
@@ -1,5 +1,5 @@
 name: post_fimo
-
+task_type: fimo
 
 fimo_out_dir: dpacman/data_files/processed/fimo/fimo_out_q
 processed_output_csv: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed.csv

configs/data_task/fimo/pre_fimo.yaml
CHANGED
@@ -1,5 +1,5 @@
 name: pre_fimo
-
+task_type: fimo
 
 paths:
   input_csv: dpacman/data_files/processed/remap/remap2022_crm_macs2_hg38_v1_0_clean.tsv

configs/data_task/fimo/run_fimo.yaml
CHANGED
@@ -1,5 +1,5 @@
 name: run_fimo
-
+task_type: fimo
 
 debug: true
 

configs/data_task/split/remap.yaml
CHANGED
@@ -10,7 +10,10 @@ cluster_output_paths:
 input_data_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed_seed0.parquet
 split_out_dir: dpacman/data_files/processed/splits
 
+dna_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence.json
+
 split_by: both # protein, dna, or both
+augment_rc: true
 
 test_ratio: 0.10
 val_ratio: 0.10

configs/extras/default.yaml
ADDED
@@ -0,0 +1,8 @@
+# disable python warnings if they annoy you
+ignore_warnings: False
+
+# ask user for tags if none are provided in the config
+enforce_tags: True
+
+# pretty print config tree at the start of the run using Rich library
+print_config: True

configs/logger/aim.yaml
ADDED
@@ -0,0 +1,28 @@
+# https://aimstack.io/
+
+# example usage in lightning module:
+# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
+
+# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
+# `aim up`
+
+aim:
+  _target_: aim.pytorch_lightning.AimLogger
+  repo: ${paths.root_dir} # .aim folder will be created here
+  # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
+
+  # aim allows to group runs under experiment name
+  experiment: null # any string, set to "default" if not specified
+
+  train_metric_prefix: "train/"
+  val_metric_prefix: "val/"
+  test_metric_prefix: "test/"
+
+  # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
+  system_tracking_interval: 10 # set to null to disable system metrics tracking
+
+  # enable/disable logging of system params such as installed packages, git info, env vars, etc.
+  log_system_params: true
+
+  # enable/disable tracking console logs (default value is true)
+  capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550

configs/logger/comet.yaml
ADDED
@@ -0,0 +1,12 @@
+# https://www.comet.ml
+
+comet:
+  _target_: lightning.pytorch.loggers.comet.CometLogger
+  api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
+  save_dir: "${paths.output_dir}"
+  project_name: "lightning-hydra-template"
+  rest_api_key: null
+  # experiment_name: ""
+  experiment_key: null # set to resume experiment
+  offline: False
+  prefix: ""

configs/logger/csv.yaml
ADDED
@@ -0,0 +1,7 @@
+# csv logger built in lightning
+
+csv:
+  _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
+  save_dir: "${paths.output_dir}"
+  name: "csv/"
+  prefix: ""

configs/logger/many_loggers.yaml
ADDED
@@ -0,0 +1,9 @@
+# train with many loggers at once
+
+defaults:
+  # - comet
+  - csv
+  # - mlflow
+  # - neptune
+  - tensorboard
+  - wandb

configs/logger/mlflow.yaml
ADDED
@@ -0,0 +1,12 @@
+# https://mlflow.org
+
+mlflow:
+  _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
+  # experiment_name: ""
+  # run_name: ""
+  tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
+  tags: null
+  # save_dir: "./mlruns"
+  prefix: ""
+  artifact_location: null
+  # run_id: ""

configs/logger/neptune.yaml
ADDED
@@ -0,0 +1,9 @@
+# https://neptune.ai
+
+neptune:
+  _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
+  api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
+  project: username/lightning-hydra-template
+  # name: ""
+  log_model_checkpoints: True
+  prefix: ""

configs/logger/tensorboard.yaml
ADDED
@@ -0,0 +1,10 @@
+# https://www.tensorflow.org/tensorboard/
+
+tensorboard:
+  _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+  save_dir: "${paths.output_dir}/tensorboard/"
+  name: null
+  log_graph: False
+  default_hp_metric: True
+  prefix: ""
+  # version: ""

configs/logger/wandb.yaml
ADDED
@@ -0,0 +1,16 @@
+# https://wandb.ai
+
+wandb:
+  _target_: lightning.pytorch.loggers.wandb.WandbLogger
+  # name: "" # name of the run (normally generated by wandb)
+  save_dir: "${paths.output_dir}"
+  offline: False
+  id: null # pass correct id to resume experiment!
+  anonymous: null # enable anonymous logging
+  project: "dnabind"
+  log_model: False # upload lightning ckpts
+  prefix: "" # a string to put at the beginning of metric keys
+  # entity: "" # set to name of your wandb team
+  group: ""
+  tags: []
+  job_type: ""

configs/model/classifier.yaml
ADDED
@@ -0,0 +1,9 @@
+_target_: dpacman.classifier.model.BindPredictor
+
+lr: 1e-4
+alpha: 20
+gamma: 20
+weight_decay: 0.01
+
+glm_input_dim: 1029
+compressed_dim: 1029

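Because the config carries a `_target_`, Hydra builds the model straight from it. A minimal sketch, assuming the composed training config exposes this file under `cfg.model` (see configs/train.yaml below):

    from hydra.utils import instantiate

    # Equivalent to BindPredictor(lr=1e-4, alpha=20, gamma=20, weight_decay=0.01,
    #                             glm_input_dim=1029, compressed_dim=1029)
    model = instantiate(cfg.model)
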
configs/{models → model}/pooling/truncatedsvd.yaml
RENAMED
File without changes

configs/models/classifier.yaml
DELETED
@@ -1,11 +0,0 @@
-name: classifier
-type: train
-
-params:
-  epochs: 10
-  batch_size: 32
-  lr: 1e-4
-  seed: 42
-
-out_dir: null
-pair_list: null

configs/preprocess.yaml
CHANGED
@@ -6,4 +6,4 @@ defaults:
   - hydra: default # ← tells Hydra to use the logging/output config
   - data_task: download/genome
 
-task_name: preprocess/${data_task.
+task_name: preprocess/${data_task.task_type}

configs/train.yaml
CHANGED
@@ -2,7 +2,42 @@ defaults:
   - _self_
   - paths: default
   - hydra: default # ← tells Hydra to use the logging/output config
+  - data_module: pair
+  - model: classifier
   - trainer: gpu
-
+  - extras: default
+  - logger: wandb
+  - callbacks: default
 
-
+  # experiment configs allow for version control of specific hyperparameters
+  # e.g. best hyperparameters for given model and datamodule
+  - experiment: null
+
+  # config for hyperparameter optimization
+  - hparams_search: null
+
+  # debugging config (enable through command line, e.g. `python train.py debug=default`)
+  - debug: null
+
+task_name: train/${model}
+
+# tags to help you identify your experiments
+# you can overwrite this in experiment configs
+# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
+tags: ["dev"]
+
+# set False to skip model training
+train: True
+
+# evaluate on test set, using best model weights achieved during training
+# lightning chooses best weights based on the metric specified in checkpoint callback
+test: True
+
+# simply provide checkpoint path to resume training
+ckpt_path: null
+
+# seed for random number generators in pytorch, numpy and python.random
+seed: 42
+
+trainer:
+  max_epochs: 20

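With this composition, any of the config groups above can be swapped from the command line (for example `python train.py trainer=cpu logger=csv`, assuming train.py is the Hydra entrypoint for this config). A short sketch using Hydra's compose API to preview the merged config without launching a run; the relative `configs` path is an assumption based on this layout:

    from hydra import initialize, compose
    from omegaconf import OmegaConf

    with initialize(version_base=None, config_path="configs"):
        cfg = compose(config_name="train", overrides=["trainer=cpu", "logger=csv"])
        print(OmegaConf.to_yaml(cfg.trainer))  # inspect the resolved trainer block
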
configs/trainer/cpu.yaml
ADDED
@@ -0,0 +1,5 @@
+defaults:
+  - default
+
+accelerator: cpu
+devices: 1

configs/trainer/ddp.yaml
ADDED
@@ -0,0 +1,9 @@
+defaults:
+  - default
+
+strategy: ddp
+
+accelerator: gpu
+devices: 4
+num_nodes: 1
+sync_batchnorm: True

configs/trainer/ddp_sim.yaml
ADDED
@@ -0,0 +1,7 @@
+defaults:
+  - default
+
+# simulate DDP on CPU, useful for debugging
+accelerator: cpu
+devices: 2
+strategy: ddp_spawn

configs/trainer/default.yaml
ADDED
@@ -0,0 +1,19 @@
+_target_: lightning.pytorch.trainer.Trainer
+
+default_root_dir: ${paths.output_dir}
+
+min_epochs: 1 # prevents early stopping
+max_epochs: 10
+
+accelerator: cpu
+devices: 1
+
+# mixed precision for extra speed-up
+# precision: 16
+
+# perform a validation loop every N training epochs
+check_val_every_n_epoch: 1
+
+# set True to ensure deterministic results
+# makes training slower but gives more reproducibility than just setting seeds
+deterministic: False

configs/trainer/gpu.yaml
ADDED
@@ -0,0 +1,5 @@
+defaults:
+  - default
+
+accelerator: gpu
+devices: 1

configs/trainer/mps.yaml
ADDED
@@ -0,0 +1,5 @@
+defaults:
+  - default
+
+accelerator: mps
+devices: 1

dpacman/classifier/loss.py
ADDED
@@ -0,0 +1,58 @@
+"""
+Define loss functions needed for training the model
+"""
+
+import torch
+from torch.nn import functional as F
+
+
+def bce_loss_masked(logits, targets, nonpeak_mask, pos_weight=None):
+    """
+    Compute the masked Binary Cross Entropy, only on certain positions.
+    We will only compute BCE on positions where nonpeak_mask == 1.0; the mask represents non-peak positions
+    """
+    loss = F.binary_cross_entropy_with_logits(
+        logits, targets, reduction="none", pos_weight=pos_weight
+    )
+    denom = nonpeak_mask.sum().clamp_min(1.0)
+    return (loss * nonpeak_mask).sum() / denom
+
+
+def mse_peaks_only(logits, targets, peak_mask, eps=1e-8):
+    """
+    Calculate MSE on peaks only.
+    """
+    probs = torch.sigmoid(logits)
+    mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction="sum") / (
+        peak_mask.sum() + eps
+    )
+    return mse_peaks
+
+
+def calculate_loss(logits, targets, eps=1e-8, alpha=1.0, gamma=1.0):
+    """
+    Combine masked-BCE + global-MSE to get a loss value
+    """
+    # Calculate peak and non-peak masks.
+    # Anything outside a peak will have a label equal to 0.
+    nonpeak_mask = (targets == 0).float()
+    peak_mask = (targets > 0).float()
+
+    bce_nonpeak = bce_loss_masked(logits, targets, nonpeak_mask)
+    mse_peak = mse_peaks_only(logits, targets, peak_mask, eps=eps)
+
+    loss = alpha * bce_nonpeak + gamma * mse_peak
+
+    return loss
+
+
+def accuracy_percentage(logits, targets, peak_thresh=0.5):
+    """
+    Compute accuracy in predicting high-confidence peaks (probability > 0.5)
+    """
+    probs = torch.sigmoid(logits)
+    preds_bin = (probs >= 0.5).float()
+    labels = (targets >= peak_thresh).float()
+    correct = (preds_bin == labels).float().sum()
+    total = torch.numel(labels)
+    return (correct / max(1, total)).item() * 100.0

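calculate_loss combines a BCE term masked to non-peak positions with an MSE term restricted to peak positions, weighted by alpha and gamma. A quick sanity-check sketch on random tensors (shapes are illustrative; in training the logits come from BindPredictor and the labels from the pair data module):

    import torch
    from dpacman.classifier.loss import calculate_loss, accuracy_percentage

    B, Lg = 4, 128                       # four peaks, 128 positions each (illustrative)
    logits = torch.randn(B, Lg)          # raw per-nucleotide scores
    targets = torch.zeros(B, Lg)         # mostly non-peak positions...
    targets[:, 40:60] = 0.9              # ...with a high-confidence peak in the middle

    loss = calculate_loss(logits, targets, alpha=20, gamma=20)
    print(loss.item(), accuracy_percentage(logits, targets))
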
dpacman/classifier/model.py
ADDED
@@ -0,0 +1,258 @@
+"""
+Lightning Module for the binding model.
+"""
+
+import torch
+from torch import nn
+from lightning import LightningModule
+from dpacman.utils.models import set_seed
+from .loss import calculate_loss
+
+set_seed()
+
+
+class LocalCNN(nn.Module):
+    def __init__(self, dim: int = 256, kernel_size: int = 3):
+        super().__init__()
+        padding = kernel_size // 2
+        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
+        self.act = nn.GELU()
+        self.ln = nn.LayerNorm(dim)
+
+    def forward(self, x: torch.Tensor):
+        # x: (batch, L, dim)
+        out = self.conv(x.transpose(1, 2))  # → (batch, dim, L)
+        out = self.act(out)
+        out = out.transpose(1, 2)  # → (batch, L, dim)
+        return self.ln(out + x)  # residual
+
+
+class CrossModalBlock(nn.Module):
+    def __init__(self, dim: int = 256, heads: int = 8):
+        super().__init__()
+        # self-attention for both sides
+        self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
+        self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
+        self.ln_b1 = nn.LayerNorm(dim)
+        self.ln_g1 = nn.LayerNorm(dim)
+
+        self.ffn_b = nn.Sequential(
+            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+        )
+        self.ffn_g = nn.Sequential(
+            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+        )
+        self.ln_b2 = nn.LayerNorm(dim)
+        self.ln_g2 = nn.LayerNorm(dim)
+
+        # cross attention (binder queries, glm keys/values)
+        # so the DNA path is updated by the transcription factors
+        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
+        self.ln_c1 = nn.LayerNorm(dim)
+        self.ffn_c = nn.Sequential(
+            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+        )
+        self.ln_c2 = nn.LayerNorm(dim)
+
+    def forward(self, binder: torch.Tensor, glm: torch.Tensor):
+        """
+        binder: (batch, Lb, dim)
+        glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
+        returns: updated glm representation (batch, Lg, dim)
+        """
+        # binder: self-attn + ffn
+        b = binder
+        b_sa, _ = self.sa_binder(b, b, b)
+        b = self.ln_b1(b + b_sa)
+        b_ff = self.ffn_b(b)
+        b = self.ln_b2(b + b_ff)
+
+        # glm: self-attn + ffn
+        g = glm
+        g_sa, _ = self.sa_glm(g, g, g)
+        g = self.ln_g1(g + g_sa)
+        g_ff = self.ffn_g(g)
+        g = self.ln_g2(g + g_ff)
+
+        # cross-attention: glm queries binder and glm embeddings are updated
+        g_to_b_ca, _ = self.cross_attn(g, b, b)
+        g = self.ln_c1(g + g_to_b_ca)
+        g_ff = self.ffn_c(g)
+        g = self.ln_c2(g + g_ff)
+        return g  # (batch, Lg, dim)
+
+
+class DimCompressor(nn.Module):
+    """
+    Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
+    If in_dim == out_dim, behaves as identity.
+    """
+
+    def __init__(self, in_dim: int, out_dim: int = 256):
+        super().__init__()
+        if in_dim == out_dim:
+            self.net = nn.Identity()
+        else:
+            hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
+            self.net = nn.Sequential(
+                nn.LayerNorm(in_dim),
+                nn.Linear(in_dim, hidden),
+                nn.GELU(),
+                nn.Linear(hidden, out_dim),
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, L, in_dim)
+        return self.net(x)
+
+
+class BindPredictor(LightningModule):
+    def __init__(
+        self,
+        # input_dim: int = 256,  # OLD: single input dim
+        binder_input_dim: int = 1280,  # NEW: TF (binder) original dim (e.g., 1280)
+        glm_input_dim: int = 256,  # NEW: DNA/GLM original dim (e.g., 256)
+        compressed_dim: int = 256,  # NEW: learnable compressed dim
+        hidden_dim: int = 256,
+        heads: int = 8,
+        num_layers: int = 4,
+        lr: float = 1e-4,
+        alpha: float = 20,
+        gamma: float = 20,
+        use_local_cnn_on_glm: bool = True,
+        weight_decay: float = 0.01,
+    ):
+        # Init
+        super(BindPredictor, self).__init__()
+        self.save_hyperparameters()
+
+        # Learnable compressor for binder -> 256, then project to hidden
+        self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
+        self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
+
+        # GLM side stays 256 -> hidden
+        self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
+
+        self.use_local_cnn = use_local_cnn_on_glm
+        self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
+
+        self.layers = nn.ModuleList(
+            [CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)]
+        )
+
+        self.ln_out = nn.LayerNorm(hidden_dim)
+        # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())  # OLD: returned probabilities
+        self.head = nn.Linear(hidden_dim, 1)  # NEW: return logits (safe for AMP)
+
+    def forward(self, binder_emb, glm_emb):
+        """
+        binder_emb: (B, Lb, binder_input_dim)
+        glm_emb: (B, Lg, glm_input_dim)
+        Returns per-nucleotide logits for the GLM sequence: (B, Lg)
+        """
+        # Binder: learnable compression → 256 → hidden
+        b = self.binder_compress(binder_emb)  # (B, Lb, 256)
+        b = self.proj_binder(b)  # (B, Lb, hidden_dim)
+
+        # GLM: project → hidden, add local CNN context
+        g = self.proj_glm(glm_emb)  # (B, Lg, hidden_dim)
+        if self.use_local_cnn:
+            g = self.local_cnn(g)
+
+        # Cross-modal blocks: update GLM states using the binder
+        for layer in self.layers:
+            g = layer(b, g)  # (B, Lg, hidden_dim)
+
+        # Predict per-nucleotide logits on the GLM tokens:
+        # return self.head(g).squeeze(-1)  # OLD: probabilities (with Sigmoid in head)
+        return self.head(g).squeeze(
+            -1
+        )  # NEW: logits (apply sigmoid only in loss/metrics)
+
+    # ----- Lightning hooks -----
+    def training_step(self, batch, batch_idx):
+        """
+        Training step taken by the PyTorch Lightning trainer. Uses the batch returned by the data collator.
+        The collator returns a dictionary with:
+            "binder_emb"  # [B, Lb_max, Db]
+            "binder_mask"  # [B, Lb_max]
+            "glm_emb"  # [B, Lg_max, Dg]
+            "glm_mask"  # [B, Lg_max]
+            "labels"  # [B, Lg_max]
+            "ID"
+            "tr_sequence"
+            "dna_sequence"
+
+        """
+        logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+        loss = calculate_loss(
+            logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+        )
+        self.log(
+            "train/loss",
+            loss,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            batch_size=logits.size(0),
+        )
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+        loss = calculate_loss(
+            logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+        )
+        self.log(
+            "val/loss",
+            loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            batch_size=logits.size(0),
+        )
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+        loss = calculate_loss(
+            logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+        )
+        self.log(
+            "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
+        )
+        return loss
+
+    def on_train_epoch_end(self):
+        if False:
+            if self.train_auc.compute() is not None:
+                self.log("train/auroc", self.train_auc.compute(), prog_bar=True)
+            self.train_auc.reset()
+
+    def on_validation_epoch_end(self):
+        if False:
+            if self.val_auc.compute() is not None:
+                self.log("val/auroc", self.val_auc.compute(), prog_bar=True)
+            self.val_auc.reset()
+
+    def on_test_epoch_end(self):
+        if False:
+            if self.test_auc.compute() is not None:
+                self.log("test/auroc", self.test_auc.compute(), prog_bar=True)
+            self.test_auc.reset()
+
+    def configure_optimizers(self):
+        # AdamW + cosine as a sensible default
+        opt = torch.optim.AdamW(
+            self.parameters(),
+            lr=self.hparams.lr,
+            weight_decay=self.hparams.weight_decay,
+        )
+        # Scheduler is optional; comment out if you prefer a fixed LR
+        sch = torch.optim.lr_scheduler.CosineAnnealingLR(
+            opt, T_max=max(self.trainer.max_epochs, 1)
+        )
+        return {
+            "optimizer": opt,
+            "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
+        }

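BindPredictor takes a padded TF embedding and a padded DNA embedding and returns one logit per DNA position. A small shape-check sketch with random tensors, using the constructor defaults shown above (real inputs come from the pair data module's collator):

    import torch
    from dpacman.classifier.model import BindPredictor

    model = BindPredictor()              # binder_input_dim=1280, glm_input_dim=256 by default
    binder = torch.randn(2, 300, 1280)   # (B, Lb, Db): two TF sequences, 300 residues each
    dna = torch.randn(2, 500, 256)       # (B, Lg, Dg): two 500-position DNA peaks

    with torch.no_grad():
        logits = model(binder, dna)
    print(logits.shape)                  # torch.Size([2, 500])
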
dpacman/classifier/model/clustering_data.py
DELETED
@@ -1,383 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import numpy as np
-import pandas as pd
-from pathlib import Path
-import random
-import sys
-import subprocess
-from collections import defaultdict
-
-# ─────────────────────────────────────────────────────────────────────────
-# Original helpers (kept; some lightly edited/commented where needed)
-# ─────────────────────────────────────────────────────────────────────────
-
-def read_ids_file(p):
-    p = Path(p)
-    if not p.exists():
-        raise FileNotFoundError(f"IDs file not found: {p}")
-    return [line.strip() for line in p.open() if line.strip()]
-
-def split_embeddings(emb_path, ids_path, out_dir, prefix):
-    out_dir = Path(out_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
-
-    if not Path(emb_path).exists():
-        raise FileNotFoundError(f"Embedding file not found: {emb_path}")
-    if not Path(ids_path).exists():
-        raise FileNotFoundError(f"IDs file not found: {ids_path}")
-
-    if emb_path.endswith(".npz"):
-        data = np.load(emb_path, allow_pickle=True)
-        if "embeddings" in data:
-            emb = data["embeddings"]
-        else:
-            raise ValueError(f"{emb_path} missing 'embeddings' key")
-    else:
-        emb = np.load(emb_path)
-
-    ids = read_ids_file(ids_path)
-    if len(ids) != emb.shape[0]:
-        print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)
-
-    mapping = {}
-    for i, ident in enumerate(ids):
-        if i >= emb.shape[0]:
-            print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
-            continue
-        arr = emb[i]
-        out_file = out_dir / f"{prefix}_{ident}.npy"
-        np.save(out_file, arr)
-        mapping[ident] = str(out_file)
-    return mapping
-
-def extract_symbol_from_tf_id(full_id: str) -> str:
-    """
-    Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
-    return the gene symbol uppercase (e.g., 'ZBTB5').
-    """
-    if "|" in full_id:
-        try:
-            # format sp|Accession|SYMBOL_HUMAN
-            genepart = full_id.split("|")[2]
-        except IndexError:
-            genepart = full_id
-    else:
-        genepart = full_id
-    symbol = genepart.split("_")[0]
-    return symbol.upper()
-
-def build_tf_symbol_map(tf_map):
-    """
-    Build mapping gene_symbol -> list of embedding paths.
-    """
-    symbol_map = {}
-    for full_id, path in tf_map.items():
-        symbol = extract_symbol_from_tf_id(full_id)
-        symbol_map.setdefault(symbol, []).append(path)
-    return symbol_map
-
-def tf_key_from_path(path: str) -> str:
-    """
-    Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
-    """
-    stem = Path(path).stem  # e.g., tf_sp|O15062|ZBTB5_HUMAN
-    # remove leading prefix if present (tf_)
-    if "_" in stem:
-        _, rest = stem.split("_", 1)
-    else:
-        rest = stem
-    return extract_symbol_from_tf_id(rest)
-
-def dna_key_from_path(path: str) -> str:
-    """
-    Given .../dna_peak42.npy -> 'peak42'
-    """
-    stem = Path(path).stem
-    if "_" in stem:
-        _, rest = stem.split("_", 1)
-    else:
-        rest = stem
-    return rest
-
-# ─────────────────────────────────────────────────────────────────────────
-# New helpers for MMseqs clustering & cluster-level splitting
-# ─────────────────────────────────────────────────────────────────────────
-
-def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
-    """
-    Write unique DNA sequences to FASTA using dna_id as header.
-    Requires df with columns: dna_id, dna_sequence
-    """
-    uniq = df[["dna_id", "dna_sequence"]].drop_duplicates()
-    with open(out_fasta, "w") as f:
-        for _, row in uniq.iterrows():
-            did = row["dna_id"]
-            seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
-            f.write(f">{did}\n{seq}\n")
-
-def run_mmseqs_easy_cluster(
-    mmseqs_bin: str,
-    fasta: Path,
-    out_prefix: Path,
-    tmp_dir: Path,
-    min_seq_id: float,
-    coverage: float,
-    cov_mode: int,
-) -> Path:
-    """
-    Runs mmseqs easy-cluster on nucleotide sequences.
-    Returns the path to a clusters TSV file (creating it if the default one isn't present).
-    """
-    tmp_dir.mkdir(parents=True, exist_ok=True)
-    out_prefix.parent.mkdir(parents=True, exist_ok=True)
-
-    cmd = [
-        mmseqs_bin, "easy-cluster",
-        str(fasta), str(out_prefix), str(tmp_dir),
-        "--min-seq-id", str(min_seq_id),
-        "-c", str(coverage),
-        "--cov-mode", str(cov_mode),
-        # You can add performance flags here if needed, e.g.:
-        # "--threads", "8"
-    ]
-    print("[i] Running:", " ".join(cmd), flush=True)
-    subprocess.run(cmd, check=True)
-
-    # MMseqs easy-cluster typically writes <out_prefix>_cluster.tsv
-    default_tsv = Path(str(out_prefix) + "_cluster.tsv")
-    if default_tsv.exists():
-        print(f"[i] Found cluster TSV: {default_tsv}")
-        return default_tsv
-
-    # Fallback: try createtsv if default is missing
-    # This requires the internal DBs. easy-cluster creates DBs alongside out_prefix.
-    # We'll try to locate them and emit a TSV.
-    in_db = Path(str(out_prefix) + "_query")
-    cl_db = Path(str(out_prefix) + "_cluster")
-    out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
-    if in_db.exists() and cl_db.exists():
-        cmd2 = [mmseqs_bin, "createtsv", str(in_db), str(in_db), str(cl_db), str(out_tsv)]
-        print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
-        subprocess.run(cmd2, check=True)
-        if out_tsv.exists():
-            return out_tsv
-
-    raise FileNotFoundError("Could not locate clusters TSV from mmseqs. "
-                            "Expected {default_tsv} or createtsv fallback.")
-
-def parse_mmseqs_clusters(tsv_path: Path) -> dict:
-    """
-    Parse MMseqs cluster TSV (rep \t member). Returns dna_id -> cluster_rep_id
-    """
-    mapping = {}
-    with open(tsv_path) as f:
-        for line in f:
-            parts = line.rstrip("\n").split("\t")
-            if len(parts) < 2:
-                continue
-            rep, member = parts[0], parts[1]
-            mapping[member] = rep
-            # Some TSVs include rep->rep; if not, ensure rep is mapped to itself:
-            if rep not in mapping:
-                mapping[rep] = rep
-    return mapping
-
-def assign_clusters_to_splits(cluster_rep_to_members: dict,
-                              val_frac: float,
-                              test_frac: float,
-                              seed: int = 42):
-    """
-    cluster_rep_to_members: dict[rep] = [members...]
-    Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
-    Ensures all members of a cluster go to the same split.
-    """
-    rng = random.Random(seed)
-    reps = list(cluster_rep_to_members.keys())
-    rng.shuffle(reps)
-
-    # Greedy-ish fill by total member counts to match desired fractions.
-    total = sum(len(cluster_rep_to_members[r]) for r in reps)
-    target_val = int(round(total * val_frac))
-    target_test = int(round(total * test_frac))
-    cur_val = cur_test = 0
-
-    val_ids, test_ids, train_ids = set(), set(), set()
-    for rep in reps:
-        members = cluster_rep_to_members[rep]
-        c = len(members)
-        # Fill val first, then test, then train
-        if cur_val + c <= target_val:
-            val_ids.update(members); cur_val += c
-        elif cur_test + c <= target_test:
-            test_ids.update(members); cur_test += c
-        else:
-            train_ids.update(members)
-
-    return {"train": train_ids, "val": val_ids, "test": test_ids}
-
-# ─────────────────────────────────────────────────────────────────────────
-# Main
-# ─────────────────────────────────────────────────────────────────────────
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
-    )
-    parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
-    parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
-    parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)")
-    parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
-    parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)")
-    parser.add_argument("--out_dir", required=True, help="Output directory")
-    parser.add_argument("--seed", type=int, default=42)
-
-    # NEW: MMseqs options & split fractions
-    parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
-    parser.add_argument("--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id")
-    parser.add_argument("--cov", type=float, default=0.8, help="MMseqs -c coverage fraction")
-    parser.add_argument("--cov_mode", type=int, default=1, help="MMseqs --cov-mode (1 = coverage of target)")
-    parser.add_argument("--val_frac", type=float, default=0.10)
-    parser.add_argument("--test_frac", type=float, default=0.10)
-    parser.add_argument("--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)")
-    args = parser.parse_args()
-
-    random.seed(args.seed)
-    out_dir = Path(args.out_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
-
-    # Load final.csv
-    df = pd.read_csv(args.final_csv, dtype=str)
-    if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
-        raise RuntimeError("final.csv must have columns TF_id and dna_sequence")
-
-    # Assign dna_id (unique per dna_sequence)
-    unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
-    seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
-    df["dna_id"] = df["dna_sequence"].map(seq_to_id)
-    enriched_csv = out_dir / "final_with_dna_id.csv"
-    df.to_csv(enriched_csv, index=False)
-    print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
-
-    # Split embeddings into per-item files (unchanged)
-    print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
-    dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
-    print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
-    print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
-    tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
-    print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")
-
-    # Build gene-symbol normalized map
-    tf_symbol_map = build_tf_symbol_map(tf_map)
-    print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")
-
-    # Diagnostic overlaps
-    norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
-    available_tf_symbols = set(tf_symbol_map.keys())
-    intersect_tf = norm_tf_in_final & available_tf_symbols
-    print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
-    print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
-    print(f"[i] Intersection count: {len(intersect_tf)}")
-    if len(intersect_tf) == 0:
-        print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
-        print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
-        print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
-        sys.exit(1)
-
-    dna_ids_final = set(df["dna_id"].unique())
-    available_dna_ids = set(dna_map.keys())
-    intersect_dna = dna_ids_final & available_dna_ids
-    print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
-    if len(intersect_dna) == 0:
-        print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
-        sys.exit(1)
-
-    # ── NEW: MMseqs clustering on DNA sequences ───────────────────────────
-    fasta_path = out_dir / "dna_unique.fasta"
-    write_dna_fasta(df, fasta_path)
-    print(f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}")
-
-    tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
-    cluster_prefix = out_dir / "mmseqs_dna_clusters"
-    clusters_tsv = run_mmseqs_easy_cluster(
-        mmseqs_bin=args.mmseqs_bin,
-        fasta=fasta_path,
-        out_prefix=cluster_prefix,
-        tmp_dir=tmp_dir,
-        min_seq_id=args.min_seq_id,
-        coverage=args.cov,
-        cov_mode=args.cov_mode,
-    )
-
-    # Parse clusters
-    member_to_rep = parse_mmseqs_clusters(clusters_tsv)  # dna_id -> rep_id
-    # Build rep -> members list
-    rep_to_members = defaultdict(list)
-    for member, rep in member_to_rep.items():
-        rep_to_members[rep].append(member)
-
-    print(f"[i] Parsed {len(rep_to_members)} clusters from {clusters_tsv}")
-    clusters_table = []
-    for rep, members in rep_to_members.items():
-        for m in members:
-            clusters_table.append((m, rep))
-    clusters_df = pd.DataFrame(clusters_table, columns=["dna_id", "cluster_id"])
-    clusters_df.to_csv(out_dir / "clusters.tsv", sep="\t", index=False)
-    print(f"[i] Wrote clusters mapping → {out_dir / 'clusters.tsv'}")
-
-    # Attach cluster_id back to final df
-    df = df.merge(clusters_df, on="dna_id", how="left")
-    df.to_csv(out_dir / "final_with_dna_id_and_cluster.csv", index=False)
-    print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")
-
-    # Assign entire clusters to splits
-    splits = assign_clusters_to_splits(rep_to_members,
-                                       val_frac=args.val_frac,
-                                       test_frac=args.test_frac,
-                                       seed=args.seed)
-    for k in ["train", "val", "test"]:
-        print(f"[i] {k}: {len(splits[k])} dna_ids")
-
-    # ── Build positive pairs only, per split (NO negatives) ───────────────
-    positives_by_split = {"train": [], "val": [], "test": []}
-    # Build a quick dna_id -> embedding path map
-    dnaid_to_path = {did: path for did, path in dna_map.items()}
-
-    pos_count = 0
-    for _, row in df.iterrows():
-        tf_raw = row["TF_id"]
-        tf_symbol = tf_raw.split("_seq")[0].upper()
-        dnaid = row["dna_id"]
-        if (tf_symbol not in tf_symbol_map) or (dnaid not in dnaid_to_path):
-            continue
-        tf_embedding_path = tf_symbol_map[tf_symbol][0]  # first embedding per symbol
-
-        # decide split by dna_id cluster assignment
-        if dnaid in splits["train"]:
-            positives_by_split["train"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
-        elif dnaid in splits["val"]:
-            positives_by_split["val"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
-        elif dnaid in splits["test"]:
-            positives_by_split["test"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
-        pos_count += 1
-
-    print(f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})")
-    for k in ["train", "val", "test"]:
-        print(f"[i] positives[{k}] = {len(positives_by_split[k])}")
-
-    # # OLD: negatives (kept commented)
-    # negatives = []
-    # print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive not used)")
-
-    # Emit split-specific pair lists
-    for split in ["train", "val", "test"]:
-        out_tsv = out_dir / f"pair_list_{split}.tsv"
-        with open(out_tsv, "w") as f:
-            for binder_path, glm_path, label in positives_by_split[split]:  # + negatives if you add later
-                f.write(f"{binder_path}\t{glm_path}\t{label}\n")
-        print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")
-
-    print("✅ Done. Cluster-aware splits ready.")
-
-if __name__ == "__main__":
-    main()

dpacman/classifier/model/compress_embeddings.py
DELETED
@@ -1,54 +0,0 @@
# compress_embeddings.py
# USAGE: python compress_embeddings.py --input_glob "/path/to/esm_embeddings/*.npy" --output_dir "/path/to/compressed_embeddings" --esm_dim 1280 --out_dim 256
# --------------
import os
import glob
import numpy as np
import torch
from torch import nn

class EmbeddingCompressor(nn.Module):
    def __init__(self, input_dim: int = 1280, output_dim: int = 256):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, L, input_dim) or (L, input_dim)
        returns: (batch, output_dim) or (output_dim,)
        """
        if x.dim() == 2:
            # single example: mean over tokens
            x = x.mean(dim=0, keepdim=True)  # → (1, input_dim)
        else:
            # batch: mean over tokens
            x = x.mean(dim=1)  # → (batch, input_dim)
        return self.fc(x)  # → (batch, output_dim)

def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
    arr = np.load(in_path)  # shape (L, D) or (batch, L, D)
    tensor = torch.from_numpy(arr).float()
    with torch.no_grad():
        compressed = model(tensor)  # → (batch, 256)
    out = compressed.cpu().numpy()
    np.save(out_path, out)
    print(f"Saved {out_path}")

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256d")
    parser.add_argument("--input_glob", type=str, required=True,
                        help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--esm_dim", type=int, default=1280)
    parser.add_argument("--out_dim", type=int, default=256)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    compressor = EmbeddingCompressor(args.esm_dim, args.out_dim)
    compressor.eval()

    for fn in glob.glob(args.input_glob):
        base = os.path.basename(fn).replace(".npy", "_256.npy")
        out_path = os.path.join(args.output_dir, base)
        compress_file(fn, out_path, compressor)
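One point worth noting about the deleted script above: the nn.Linear layer is never trained or loaded from a checkpoint inside this file, so as written it applies a random projection after mean-pooling. A minimal sketch of calling it on a dummy array, assuming the module is importable as compress_embeddings:

# Quick sanity check of EmbeddingCompressor on a dummy ESM-style array.
# Note: the linear layer is randomly initialised here; unless trained weights
# are loaded with load_state_dict, this is a random 1280 → 256 projection.
import numpy as np
import torch
from compress_embeddings import EmbeddingCompressor  # assumed import path

dummy = np.random.randn(250, 1280).astype("float32")   # (L, D) per-token embedding
model = EmbeddingCompressor(input_dim=1280, output_dim=256).eval()
with torch.no_grad():
    out = model(torch.from_numpy(dummy))
print(out.shape)  # torch.Size([1, 256]): mean over tokens, then projection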
dpacman/classifier/model/compute_embeddings.py
DELETED
@@ -1,560 +0,0 @@
"""
Plug-and-play embedding extraction for:
  • Chromosome sequences (from raw UCSC JSON)
  • TF sequences (transcription_factors.fasta)

Usage example (DNA + protein in one go):
  module load miniconda/24.7.1
  conda activate dpacman
  python dpacman/data/compute_embeddings.py \
    --genome-json-dir ../data_files/raw/genomes/hg38 \
    --tf-fasta ../data_files/processed/tfclust/hg38_tf/transcription_factors.fasta \
    --chrom-model caduceus \
    --tf-model esm-dbp \
    --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
    --device cuda
"""
import os
import re
import argparse
import json
import numpy as np
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline
import esm
from Bio import SeqIO
import time
import pandas as pd
from tqdm.auto import tqdm
import logging, math

# ---- model wrappers ----

class CaduceusEmbedder:
    def __init__(self, device, chunk_size=131_072, overlap=0):
        """
        device: 'cpu' or 'cuda'
        chunk_size: max bases (and thus tokens) to send in one forward pass
        overlap: how many bases each window overlaps the previous; 0 = no overlap
        """
        model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name, trust_remote_code=True
        ).to(device).eval()
        self.device = device
        self.chunk_size = chunk_size
        self.step = chunk_size - overlap

    def embed(self, seqs):
        """
        seqs: List[str] of DNA sequences (each <= chunk_size for this test)
        returns: raw per-token embeddings
        """
        # outputs = []
        # for seq in seqs:
        #     # --- new: raw per-token embeddings in one shot ---
        #     toks = self.tokenizer(
        #         seq,
        #         return_tensors="pt",
        #         padding=False,
        #         truncation=True,
        #         max_length=self.chunk_size
        #     ).to(self.device)
        #     with torch.no_grad():
        #         out = self.model(**toks).last_hidden_state  # (1, L, D)
        #     outputs.append(out.cpu().numpy()[0])  # (L, D)
        # return np.stack(outputs, axis=0)  # (N, L, D)
        outputs = []
        for seq in tqdm(seqs, total=len(seqs), desc="DNA: Caduceus", dynamic_ncols=True):
            toks = self.tokenizer(
                seq,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=self.chunk_size
            ).to(self.device)
            with torch.no_grad():
                out = self.model(**toks).last_hidden_state  # (1, L, D)
            outputs.append(out.cpu().numpy()[0])  # (L, D)
        return outputs  # list of variable-length (L_i, D) arrays

    def benchmark(self, lengths=None):
        """
        Time embedding on single sequences of various lengths.
        By default tests [5K, 10K, 50K, 100K, chunk_size].
        """
        tests = lengths or [5_000, 10_000, 50_000, 100_000, self.chunk_size]
        print(f"→ Benchmarking Caduceus on device={self.device}")
        for sz in tests:
            seq = "A" * sz
            # Warm-up
            _ = self.embed([seq])
            if self.device != "cpu":
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            _ = self.embed([seq])
            if self.device != "cpu":
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            print(f"  length={sz:6,d}  time={(t1-t0)*1000:7.1f} ms")

class SegmentNTEmbedder:
    def __init__(self, device):
        self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
        self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval()
        self.device = device

    def _adjust_length(self, input_ids):
        bs, L = input_ids.shape
        excl = L - 1
        remainder = (excl) % 4
        if remainder != 0:
            pad_needed = 4 - remainder
            pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
            input_ids = torch.cat([input_ids, pad_tensor], dim=1)
        return input_ids

    def embed(self, seqs, batch_size=16):
        """
        seqs: List[str]
        Returns: np.ndarray of shape (N, D)
        """
        all_embeddings = []
        for i in range(0, len(seqs), batch_size):
            batch_seqs = seqs[i : i + batch_size]
            encoded = self.tokenizer.batch_encode_plus(
                batch_seqs,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            input_ids = encoded["input_ids"].to(self.device)  # (B, L)
            attention_mask = input_ids != self.tokenizer.pad_token_id

            input_ids = self._adjust_length(input_ids)
            attention_mask = (input_ids != self.tokenizer.pad_token_id)

            with torch.no_grad():
                outs = self.model(
                    input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    return_dict=True,
                )
            if hasattr(outs, "hidden_states") and outs.hidden_states is not None:
                last_hidden = outs.hidden_states[-1]  # (B, L, D)
            else:
                last_hidden = outs.last_hidden_state  # fallback

            # Exclude CLS token if present (assume first token) and pool
            pooled = last_hidden[:, 1:, :].mean(dim=1)  # (B, D)
            all_embeddings.append(pooled.cpu().numpy())

            # release fragmentation
            torch.cuda.empty_cache()

        return np.vstack(all_embeddings)  # (N, D)

class DNABertEmbedder:
    def __init__(self, device):
        self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
        self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device)
        self.device = device

    def embed(self, seqs):
        embs = []
        for s in seqs:
            tokens = self.tokenizer(s, return_tensors="pt", padding=True)["input_ids"].to(self.device)
            with torch.no_grad():
                out = self.model(tokens).last_hidden_state.mean(1)
            embs.append(out.cpu().numpy())
        return np.vstack(embs)

class NucleotideTransformerEmbedder:
    def __init__(self, device):
        # HF "feature-extraction" returns a list of (L, D) arrays for each input
        # device: "cpu" or "cuda"
        self.pipe = pipeline(
            "feature-extraction",
            model="InstaDeepAI/nucleotide-transformer-500m-1000g",
            device= -1 if device=="cpu" else 0  # HF uses -1 for CPU, 0 for GPU
        )

    def embed(self, seqs):
        """
        seqs: List[str] of raw DNA sequences
        returns: (N, D) array, one D-dim vector per sequence
        """
        all_embeddings = self.pipe(seqs, truncation=True, padding=True)
        # all_embeddings is a List of shape (L, D) arrays
        pooled = [np.mean(x, axis=0) for x in all_embeddings]
        return np.vstack(pooled)

# class ESMEmbedder:
#     def __init__(self, device):
#         self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
#         self.batch_converter = self.alphabet.get_batch_converter()
#         self.model.to(device).eval()
#         self.device = device
#
#     def embed(self, seqs):
#         batch = [(str(i), seq) for i, seq in enumerate(seqs)]
#         _, _, toks = self.batch_converter(batch)
#         toks = toks.to(self.device)
#         with torch.no_grad():
#             results = self.model(toks, repr_layers=[33], return_contacts=False)
#         reps = results["representations"][33]
#         return reps[:, 1:-1].mean(1).cpu().numpy()

class ESMEmbedder:
    def __init__(self, device, model_name="esm2_t33_650M_UR50D"):
        # Try to load the specified ESM-2 model; fallback to esm1b if missing
        self.device = device
        try:
            self.model, self.alphabet = getattr(esm.pretrained, model_name)()
            self.is_esm2 = model_name.lower().startswith("esm2")
        except AttributeError:
            # fallback to ESM-1b
            self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
            self.is_esm2 = False
        self.batch_converter = self.alphabet.get_batch_converter()
        self.model.to(device).eval()
        # determine max length: esm2 models vary; use default 1024 for esm1b
        self.max_len = 4096 if self.is_esm2 else 1024  # adjust if your esm2 variant has explicit limit
        # for chunking: reserve 2 tokens if model uses BOS/EOS
        self.chunk_size = self.max_len - 2
        self.overlap = self.chunk_size // 4  # 25% overlap to smooth boundaries

    def _chunk_sequence(self, seq):
        """
        Return list of possibly overlapping chunks of seq, each <= chunk_size.
        """
        if len(seq) <= self.chunk_size:
            return [seq]
        step = self.chunk_size - self.overlap
        chunks = []
        for i in range(0, len(seq), step):
            chunk = seq[i : i + self.chunk_size]
            if not chunk:
                break
            chunks.append(chunk)
        return chunks

    def embed(self, seqs):
        """
        seqs: List[str] of protein sequences.
        Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings.
        """
        all_embeddings = []
        for i, seq in enumerate(seqs):
            chunks = self._chunk_sequence(seq)
            chunk_vecs = []
            # process chunks in batch if small number, else sequentially
            for chunk in chunks:
                batch = [(str(i), chunk)]
                _, _, toks = self.batch_converter(batch)
                toks = toks.to(self.device)
                with torch.no_grad():
                    results = self.model(toks, repr_layers=[33], return_contacts=False)
                reps = results["representations"][33]  # (1, L, D)
                # remove BOS/EOS if present: take 1:-1 if length permits
                if reps.size(1) > 2:
                    rep = reps[:, 1:-1].mean(1)  # (1, D)
                else:
                    rep = reps.mean(1)  # fallback
                chunk_vecs.append(rep.squeeze(0))  # (D,)
            if len(chunk_vecs) == 1:
                seq_vec = chunk_vecs[0]
            else:
                # average chunk vectors
                stacked = torch.stack(chunk_vecs, dim=0)  # (num_chunks, D)
                seq_vec = stacked.mean(0)
            all_embeddings.append(seq_vec.cpu().numpy())
        return np.vstack(all_embeddings)  # (N, D)

# class ESMDBPEmbedder:
#     def __init__(self, device):
#         base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
#         model_path = (
#             Path(__file__).resolve().parent.parent
#             / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
#         )
#         checkpoint = torch.load(model_path, map_location="cpu")
#         clean_sd = {}
#         for k, v in checkpoint.items():
#             clean_sd[k.replace("module.", "")] = v
#         result = base_model.load_state_dict(clean_sd, strict=False)
#         if result.missing_keys:
#             print(f"[ESMDBP] missing keys: {result.missing_keys}")
#         if result.unexpected_keys:
#             print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
#
#         self.model = base_model.to(device).eval()
#         self.alphabet = alphabet
#         self.batch_converter = alphabet.get_batch_converter()
#         self.device = device
#
#     def embed(self, seqs):
#         batch = [(str(i), seq) for i, seq in enumerate(seqs)]
#         _, _, toks = self.batch_converter(batch)
#         toks = toks.to(self.device)
#         with torch.no_grad():
#             out = self.model(toks, repr_layers=[33], return_contacts=False)
#         reps = out["representations"][33]
#         # skip start/end tokens
#         return reps[:, 1:-1].mean(1).cpu().numpy()

class ESMDBPEmbedder:
    def __init__(self, device):
        base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
        model_path = (
            Path(__file__).resolve().parent.parent
            / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
        )
        checkpoint = torch.load(model_path, map_location="cpu")
        clean_sd = {}
        for k, v in checkpoint.items():
            clean_sd[k.replace("module.", "")] = v
        result = base_model.load_state_dict(clean_sd, strict=False)
        if result.missing_keys:
            print(f"[ESMDBP] missing keys: {result.missing_keys}")
        if result.unexpected_keys:
            print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")

        self.model = base_model.to(device).eval()
        self.alphabet = alphabet
        self.batch_converter = alphabet.get_batch_converter()
        self.device = device
        self.max_len = 1024  # same limit as esm1b
        self.chunk_size = self.max_len - 2
        self.overlap = self.chunk_size // 4

    def _chunk_sequence(self, seq):
        if len(seq) <= self.chunk_size:
            return [seq]
        step = self.chunk_size - self.overlap
        chunks = []
        for i in range(0, len(seq), step):
            chunk = seq[i : i + self.chunk_size]
            if not chunk:
                break
            chunks.append(chunk)
        return chunks

    def embed(self, seqs):
        all_embeddings = []
        for i, seq in enumerate(seqs):
            chunks = self._chunk_sequence(seq)
            chunk_vecs = []
            for chunk in chunks:
                batch = [(str(i), chunk)]
                _, _, toks = self.batch_converter(batch)
                toks = toks.to(self.device)
                with torch.no_grad():
                    out = self.model(toks, repr_layers=[33], return_contacts=False)
                reps = out["representations"][33]
                if reps.size(1) > 2:
                    rep = reps[:, 1:-1].mean(1)
                else:
                    rep = reps.mean(1)
                chunk_vecs.append(rep.squeeze(0))
            if len(chunk_vecs) == 1:
                seq_vec = chunk_vecs[0]
            else:
                stacked = torch.stack(chunk_vecs, dim=0)
                seq_vec = stacked.mean(0)
            all_embeddings.append(seq_vec.cpu().numpy())
        return np.vstack(all_embeddings)

class GPNEmbedder:
    def __init__(self, device):
        model_name = "songlab/gpn-msa-sapiens"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.model.to(device)
        self.model.eval()
        self.device = device

    def embed(self, seqs):
        inputs = self.tokenizer(
            seqs,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)

        with torch.no_grad():
            last_hidden = self.model(**inputs).last_hidden_state
        return last_hidden.mean(dim=1).cpu().numpy()

class ProGenEmbedder:
    def __init__(self, device):
        model_name = "jinyuan22/ProGen2-base"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device).eval()
        self.device = device

    def embed(self, seqs):
        inputs = self.tokenizer(
            seqs,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)
        with torch.no_grad():
            last_hidden = self.model(**inputs).last_hidden_state
        return last_hidden.mean(dim=1).cpu().numpy()

# ---- main pipeline ----

def get_embedder(name, device, for_dna=True):
    name = name.lower()
    if for_dna:
        if name == "caduceus": return CaduceusEmbedder(device)
        if name == "dnabert": return DNABertEmbedder(device)
        if name == "nucleotide": return NucleotideTransformerEmbedder(device)
        if name == "gpn": return GPNEmbedder(device)
        if name == "segmentnt": return SegmentNTEmbedder(device)
    else:
        if name in ("esm",): return ESMEmbedder(device)
        if name in ("esm-dbp", "esm_dbp"): return ESMDBPEmbedder(device)
        if name == "progen": return ProGenEmbedder(device)
    raise ValueError(f"Unknown model {name} (for_dna={for_dna})")

def pad_token_embeddings(list_of_arrays, pad_value=0.0):
    """
    list_of_arrays: list of (L_i, D) numpy arrays
    Returns:
      padded: (N, L_max, D) array
      mask: (N, L_max) boolean array where True = real token, False = padding
    """
    N = len(list_of_arrays)
    D = list_of_arrays[0].shape[1]
    L_max = max(arr.shape[0] for arr in list_of_arrays)
    padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
    mask = np.zeros((N, L_max), dtype=bool)
    for i, arr in enumerate(list_of_arrays):
        L = arr.shape[0]
        padded[i, :L] = arr
        mask[i, :L] = True
    return padded, mask

def embed_and_save(seqs, ids, embedder, out_path):
    embs = embedder.embed(seqs)

    # Decide whether we got variable-length per-token outputs (list of (L, D))
    is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and embs[0].ndim == 2

    if is_variable_token:
        # pad to (N, L_max, D) + mask
        padded, mask = pad_token_embeddings(embs)
        # Save both embeddings and mask together in an .npz for convenience
        np.savez_compressed(out_path.with_suffix(".caduceus.npz"),
                            embeddings=padded,
                            mask=mask,
                            ids=np.array(ids, dtype=object))
    else:
        # fixed shape output, e.g., pooled (N, D)
        array = np.vstack(embs) if isinstance(embs, list) else embs
        np.save(out_path, array)
    with open(out_path.with_suffix(".ids"), "w") as f:
        f.write("\n".join(ids))

if __name__ == "__main__":

    p = argparse.ArgumentParser()
    #p.add_argument("--peak_fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
    p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
    p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step")  # if glm embeddings successful but not plm embeddings
    p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
    p.add_argument("--chrom-model", default="caduceus")
    p.add_argument("--tf-model", default="esm-dbp")
    p.add_argument("--out-dir", default="dpacman/model/embeddings")
    p.add_argument("--device", default="cpu")
    args = p.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    device = args.device
    print(device)

    if not args.skip_dna:
        if args.genome_json_dir == None:
            dna_df = pd.read_parquet('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.parquet', engine='pyarrow')
            #df.to_csv('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.csv', index=False)
            peak_seqs = dna_df["dna_sequence"]
            peak_ids = dna_df["ID"]
            print(f"Embedding {len(peak_seqs)} binding peak sequences from processed remap data", flush=True)
            dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
            out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
            embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)

        # peak_fasta = Path(args.peak_fasta)
        # if peak_fasta.exists():
        #     # Load peak sequences from FASTA
        #     from Bio import SeqIO
        #
        #     peak_seqs = []
        #     peak_ids = []
        #     for rec in SeqIO.parse(peak_fasta, "fasta"):
        #         peak_ids.append(rec.id)
        #         peak_seqs.append(str(rec.seq))
        #     print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
        #     dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
        #     out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
        #     embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
        elif args.genome_json_dir:
            # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
            genome_dir = Path(args.genome_json_dir)
            chrom_seqs, chrom_ids = [], []
            primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$")
            for j in sorted(genome_dir.iterdir()):
                if not primary_pattern.match(j.name):
                    continue
                data = json.loads(j.read_text())
                seq = data.get("dna") or data.get("sequence")
                chrom = data.get("chrom") or j.stem.split("_")[-1]
                chrom_seqs.append(seq)
                chrom_ids.append(chrom)
            cutoff = CaduceusEmbedder(device).chunk_size
            long_chroms = [
                (chrom, len(seq))
                for chrom, seq in zip(chrom_ids, chrom_seqs)
                if len(seq) > cutoff
            ]
            if long_chroms:
                print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff))
                for chrom, L in long_chroms:
                    print(f"   {chrom}: {L} bases")
            else:
                print("All chromosomes ≤ Caduceus limit ({}).".format(cutoff))

            chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
            out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
            embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
        else:
            raise ValueError("No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs.")

    # Load TF sequences
    tf_seqs, tf_ids = [], []
    for record in SeqIO.parse(args.tf_fasta, "fasta"):
        tf_ids.append(record.id)
        tf_seqs.append(str(record.seq))

    # embed and save
    tf_embedder = get_embedder(args.tf_model, device, for_dna=False)
    out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
    embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)

    print("Done.")
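For the Caduceus branch, embed_and_save writes per-token embeddings plus a padding mask into a compressed .npz. A sketch of how downstream code could read that file back and masked-mean-pool each peak into a single vector; the pooling step is illustrative, not part of this commit, and the path follows the naming used above.

# Sketch of consuming the per-token DNA embeddings saved by embed_and_save.
import numpy as np

data = np.load("embeddings/peaks_caduceus.caduceus.npz", allow_pickle=True)
emb = data["embeddings"]   # (N, L_max, D), zero-padded
mask = data["mask"]        # (N, L_max) bool, True = real token
ids = data["ids"]          # (N,) peak identifiers

# Masked mean-pool each peak down to a single (D,) vector.
denom = mask.sum(axis=1, keepdims=True).clip(min=1)      # (N, 1)
pooled = (emb * mask[..., None]).sum(axis=1) / denom     # (N, D)
print(ids[0], pooled.shape)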
dpacman/classifier/model/extract_tf_symbols.py
DELETED
@@ -1,27 +0,0 @@
#!/usr/bin/env python3
import pandas as pd
from pathlib import Path

FINAL_CSV = Path("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv")
OUT_SYMBOLS = Path("tf_symbols.txt")

def normalize_tf(tf_id: str) -> str:
    return tf_id.split("_seq")[0].upper()

def main():
    df = pd.read_csv(FINAL_CSV, dtype=str)
    if "TF_id" not in df.columns:
        raise RuntimeError("final.csv missing TF_id column")
    tf_raw = df["TF_id"].dropna().unique().tolist()
    normalized = sorted({normalize_tf(t) for t in tf_raw})
    print(f"Unique raw TF_id count: {len(tf_raw)}")
    print(f"Unique normalized TF symbols: {len(normalized)}")
    with open(OUT_SYMBOLS, "w") as f:
        for s in normalized:
            f.write(s + "\n")
    print(f"Wrote normalized TF symbols to {OUT_SYMBOLS}")
    # Optional: show sample
    print("Sample symbols:", normalized[:50])

if __name__ == "__main__":
    main()
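For reference, the normalization simply keeps everything before "_seq" and upper-cases it. A couple of illustrative inputs (the TF_id strings here are made up for the example):

def normalize_tf(tf_id: str) -> str:
    return tf_id.split("_seq")[0].upper()

assert normalize_tf("ctcf_seq1") == "CTCF"
assert normalize_tf("Zbtb5_seq12") == "ZBTB5"
assert normalize_tf("GATA1") == "GATA1"   # no suffix: only case changes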
dpacman/classifier/model/loss.py
DELETED
@@ -1,34 +0,0 @@
"""
Define loss functions needed for training the model
"""
import torch
from torch.nn import functional as F

def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8):
    probs = torch.sigmoid(logits)
    labels = (targets >= peak_thresh).float()
    non_peak_mask = (labels == 0).float()
    peak_mask = (labels == 1).float()

    bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
    bce_non = (bce_all * non_peak_mask)
    bce_non = bce_non.sum() / (non_peak_mask.sum() + eps)

    mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction='sum') \
        / (peak_mask.sum() + eps)

    t_dist = (targets + eps)
    p_dist = (probs + eps)
    t_dist = t_dist / t_dist.sum(dim=1, keepdim=True)
    p_dist = p_dist / p_dist.sum(dim=1, keepdim=True)
    kl = (t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log())).sum(dim=1).mean()

    return bce_non, kl, mse_peaks, probs

def accuracy_percentage(logits, targets, peak_thresh=0.5):
    probs = torch.sigmoid(logits)
    preds_bin = (probs >= 0.5).float()
    labels = (targets >= peak_thresh).float()
    correct = (preds_bin == labels).float().sum()
    total = torch.numel(labels)
    return (correct / max(1, total)).item() * 100.0
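combined_loss_components returns its three pieces separately; how they are weighted into a single training loss is up to the caller. A minimal sketch with dummy tensors — the 1.0 / 0.5 / 1.0 coefficients are placeholders for illustration, not values taken from this repo, and the import assumes the module is importable as loss:

import torch
from loss import combined_loss_components, accuracy_percentage  # assumed import

logits = torch.randn(4, 100)    # (batch, positions)
targets = torch.rand(4, 100)    # continuous peak signal in [0, 1]

bce_non, kl, mse_peaks, probs = combined_loss_components(logits, targets, peak_thresh=0.5)

# Example weighting of the three components into one scalar loss.
total = 1.0 * bce_non + 0.5 * kl + 1.0 * mse_peaks
acc = accuracy_percentage(logits, targets)
print(float(total), acc)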
dpacman/classifier/model/make_pair_list.py
DELETED
@@ -1,220 +0,0 @@
#!/usr/bin/env python3
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
import random
import sys

def read_ids_file(p):
    p = Path(p)
    if not p.exists():
        raise FileNotFoundError(f"IDs file not found: {p}")
    return [line.strip() for line in p.open() if line.strip()]

def split_embeddings(emb_path, ids_path, out_dir, prefix):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if not Path(emb_path).exists():
        raise FileNotFoundError(f"Embedding file not found: {emb_path}")
    if not Path(ids_path).exists():
        raise FileNotFoundError(f"IDs file not found: {ids_path}")

    if emb_path.endswith(".npz"):
        data = np.load(emb_path, allow_pickle=True)
        if "embeddings" in data:
            emb = data["embeddings"]
        else:
            raise ValueError(f"{emb_path} missing 'embeddings' key")
    else:
        emb = np.load(emb_path)

    ids = read_ids_file(ids_path)
    if len(ids) != emb.shape[0]:
        print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)

    mapping = {}
    for i, ident in enumerate(ids):
        if i >= emb.shape[0]:
            print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
            continue
        arr = emb[i]
        out_file = out_dir / f"{prefix}_{ident}.npy"
        np.save(out_file, arr)
        mapping[ident] = str(out_file)
    return mapping

def extract_symbol_from_tf_id(full_id: str) -> str:
    """
    Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
    return the gene symbol uppercase (e.g., 'ZBTB5').
    """
    if "|" in full_id:
        try:
            # format sp|Accession|SYMBOL_HUMAN
            genepart = full_id.split("|")[2]
        except IndexError:
            genepart = full_id
    else:
        genepart = full_id
    symbol = genepart.split("_")[0]
    return symbol.upper()

def build_tf_symbol_map(tf_map):
    """
    Build mapping gene_symbol -> list of embedding paths.
    """
    symbol_map = {}
    for full_id, path in tf_map.items():
        symbol = extract_symbol_from_tf_id(full_id)
        symbol_map.setdefault(symbol, []).append(path)
    return symbol_map

def tf_key_from_path(path: str) -> str:
    """
    Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
    """
    stem = Path(path).stem  # e.g., tf_sp|O15062|ZBTB5_HUMAN
    # remove leading prefix if present (tf_)
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return extract_symbol_from_tf_id(rest)

def dna_key_from_path(path: str) -> str:
    """
    Given .../dna_peak42.npy -> 'peak42'
    """
    stem = Path(path).stem
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return rest

def main():
    parser = argparse.ArgumentParser(
        description="Build TF-DNA pair list from final.csv with gene-symbol normalization for TFs."
    )
    parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
    parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
    parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (e.g., peak*.ids)")
    parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
    parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (e.g., sp|...|... ids)")
    parser.add_argument("--out_dir", required=True, help="Output directory")
    parser.add_argument("--neg_per_positive", type=int, default=2, help="Negatives per positive (half same-TF, half same-DNA)")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load final.csv
    df = pd.read_csv(args.final_csv, dtype=str)
    if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
        raise RuntimeError("final.csv must have columns TF_id and dna_sequence")

    # Assign dna_id (unique per dna_sequence)
    unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
    seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
    df["dna_id"] = df["dna_sequence"].map(seq_to_id)
    enriched_csv = out_dir / "final_with_dna_id.csv"
    df.to_csv(enriched_csv, index=False)
    print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")

    # Split embeddings into per-item files
    print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
    dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
    print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
    print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
    tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
    print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")

    # Build gene-symbol normalized map
    tf_symbol_map = build_tf_symbol_map(tf_map)
    print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")

    # Diagnostic overlaps
    norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
    available_tf_symbols = set(tf_symbol_map.keys())
    intersect_tf = norm_tf_in_final & available_tf_symbols
    print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
    print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
    print(f"[i] Intersection count: {len(intersect_tf)}")
    if len(intersect_tf) == 0:
        print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
        print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
        print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
        sys.exit(1)

    dna_ids_final = set(df["dna_id"].unique())
    available_dna_ids = set(dna_map.keys())
    intersect_dna = dna_ids_final & available_dna_ids
    print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
    if len(intersect_dna) == 0:
        print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
        sys.exit(1)

    # Build positive pairs
    positives = []
    for _, row in df.iterrows():
        tf_raw = row["TF_id"]
        tf_symbol = tf_raw.split("_seq")[0].upper()
        dnaid = row["dna_id"]
        if tf_symbol not in tf_symbol_map:
            continue
        if dnaid not in dna_map:
            continue
        # pick the first embedding for that symbol
        tf_embedding_path = tf_symbol_map[tf_symbol][0]
        positives.append((tf_embedding_path, dna_map[dnaid], 1))
    print(f"[i] Constructed {len(positives)} positive pairs after TF symbol resolution")

    if len(positives) == 0:
        print("[ERROR] No positive pairs could be constructed; aborting.", file=sys.stderr)
        sys.exit(1)

    # Build negative samples
    all_tf_symbols = sorted(tf_symbol_map.keys())
    all_dnaids = sorted(dna_map.keys())
    positive_set = set()
    for tf_path, dna_path, _ in positives:
        tf_key = tf_key_from_path(tf_path)
        dna_key = dna_key_from_path(dna_path)
        positive_set.add((tf_key, dna_key))

    negatives = []
    half = args.neg_per_positive // 2
    for tf_path, dna_path, _ in positives:
        tf_key = tf_key_from_path(tf_path)
        dna_key = dna_key_from_path(dna_path)
        # same TF, different DNA
        for _ in range(half):
            candidate_dna = random.choice(all_dnaids)
            if candidate_dna == dna_key or (tf_key, candidate_dna) in positive_set:
                continue
            negatives.append((tf_path, dna_map[candidate_dna], 0))
        # same DNA, different TF
        for _ in range(half):
            candidate_tf_symbol = random.choice(all_tf_symbols)
            if candidate_tf_symbol == tf_key or (candidate_tf_symbol, dna_key) in positive_set:
                continue
            # pick its first embedding
            candidate_tf_path = tf_symbol_map[candidate_tf_symbol][0]
            negatives.append((candidate_tf_path, dna_map[dnaid], 0))

    print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive={args.neg_per_positive})")

    # Write pair list
    pair_list_path = out_dir / "pair_list.tsv"
    with open(pair_list_path, "w") as f:
        for binder_path, glm_path, label in positives + negatives:
            # binder=TF, glm=DNA
            f.write(f"{binder_path}\t{glm_path}\t{label}\n")
    print(f"[i] Wrote {len(positives)+len(negatives)} examples to {pair_list_path}")

if __name__ == "__main__":
    main()
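After this script runs, pair_list.tsv holds positives and sampled negatives in one file. A quick way to eyeball the label balance; this snippet is illustrative and the "out_dir" path is a placeholder for whatever --out_dir was used:

from collections import Counter
from pathlib import Path

pair_list = Path("out_dir/pair_list.tsv")   # hypothetical --out_dir
counts = Counter(line.rstrip("\n").split("\t")[2] for line in pair_list.open() if line.strip())
print(counts)   # e.g. Counter({'0': ..., '1': ...}): negatives vs positives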
dpacman/classifier/model/make_peak_fasta.py
DELETED
@@ -1,13 +0,0 @@
import pandas as pd
from pathlib import Path

df = pd.read_csv("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv", dtype=str)  # adjust path if needed
# get unique sequences
uniq = df[["dna_sequence"]].drop_duplicates().reset_index(drop=True)
# make headers: e.g., peak0, peak1, ...
out_fa = Path("binding_peaks_unique.fa")
with open(out_fa, "w") as f:
    for i, seq in enumerate(uniq["dna_sequence"]):
        header = f">peak{i}"
        f.write(f"{header}\n{seq}\n")
print(f"Wrote {len(uniq)} unique binding sequences to {out_fa}")
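The resulting FASTA is what the clustering and embedding steps consume; reading it back is a one-liner with Biopython. A small sketch, assuming the file was written to the default path above:

from Bio import SeqIO

records = list(SeqIO.parse("binding_peaks_unique.fa", "fasta"))
print(len(records), records[0].id)   # record ids look like "peak0", "peak1", ...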
dpacman/classifier/model_tmp/clustering_data.py
CHANGED
@@ -12,12 +12,14 @@ from collections import defaultdict
 # Original helpers (kept; some lightly edited/commented where needed)
 # ─────────────────────────────────────────────────────────────────────────
 
+
 def read_ids_file(p):
     p = Path(p)
     if not p.exists():
         raise FileNotFoundError(f"IDs file not found: {p}")
     return [line.strip() for line in p.open() if line.strip()]
 
+
 def split_embeddings(emb_path, ids_path, out_dir, prefix):
     out_dir = Path(out_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
@@ -38,12 +40,17 @@ def split_embeddings(emb_path, ids_path, out_dir, prefix):
 
     ids = read_ids_file(ids_path)
     if len(ids) != emb.shape[0]:
-        print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)
+        print(
+            f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}",
+            file=sys.stderr,
+        )
 
     mapping = {}
     for i, ident in enumerate(ids):
         if i >= emb.shape[0]:
-            print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
+            print(
+                f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr
+            )
             continue
         arr = emb[i]
         out_file = out_dir / f"{prefix}_{ident}.npy"
@@ -51,6 +58,7 @@ def split_embeddings(emb_path, ids_path, out_dir, prefix):
         mapping[ident] = str(out_file)
     return mapping
 
+
 def extract_symbol_from_tf_id(full_id: str) -> str:
     """
     Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
@@ -67,6 +75,7 @@ def extract_symbol_from_tf_id(full_id: str) -> str:
     symbol = genepart.split("_")[0]
     return symbol.upper()
 
+
 def build_tf_symbol_map(tf_map):
     """
     Build mapping gene_symbol -> list of embedding paths.
@@ -77,6 +86,7 @@ def build_tf_symbol_map(tf_map):
         symbol_map.setdefault(symbol, []).append(path)
     return symbol_map
 
+
 def tf_key_from_path(path: str) -> str:
     """
     Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
@@ -89,6 +99,7 @@ def tf_key_from_path(path: str) -> str:
         rest = stem
     return extract_symbol_from_tf_id(rest)
 
+
 def dna_key_from_path(path: str) -> str:
     """
     Given .../dna_peak42.npy -> 'peak42'
@@ -100,10 +111,12 @@ def dna_key_from_path(path: str) -> str:
         rest = stem
     return rest
 
+
 # ─────────────────────────────────────────────────────────────────────────
 # New helpers for MMseqs clustering & cluster-level splitting
 # ─────────────────────────────────────────────────────────────────────────
 
+
 def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
     """
     Write unique DNA sequences to FASTA using dna_id as header.
@@ -116,6 +129,7 @@ def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
         seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
         f.write(f">{did}\n{seq}\n")
 
+
 def run_mmseqs_easy_cluster(
     mmseqs_bin: str,
     fasta: Path,
@@ -133,11 +147,17 @@ def run_mmseqs_easy_cluster(
     out_prefix.parent.mkdir(parents=True, exist_ok=True)
 
     cmd = [
-        mmseqs_bin,
-
-
-
-
         # You can add performance flags here if needed, e.g.:
         # "--threads", "8"
     ]
@@ -157,14 +177,24 @@ def run_mmseqs_easy_cluster(
     cl_db = Path(str(out_prefix) + "_cluster")
     out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
     if in_db.exists() and cl_db.exists():
-        cmd2 = [
         print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
         subprocess.run(cmd2, check=True)
         if out_tsv.exists():
             return out_tsv
 
-    raise FileNotFoundError(
-

 def parse_mmseqs_clusters(tsv_path: Path) -> dict:
     """
@@ -174,7 +204,7 @@ def parse_mmseqs_clusters(tsv_path: Path) -> dict:
     with open(tsv_path) as f:
         for line in f:
             parts = line.rstrip("\n").split("\t")
-            if len(parts) < 2:
                 continue
             rep, member = parts[0], parts[1]
             mapping[member] = rep
@@ -183,10 +213,10 @@ def parse_mmseqs_clusters(tsv_path: Path) -> dict:
         mapping[rep] = rep
     return mapping
 
-
-
-
-
     """
     cluster_rep_to_members: dict[rep] = [members...]
     Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
@@ -208,38 +238,63 @@ def assign_clusters_to_splits(cluster_rep_to_members: dict,
         c = len(members)
         # Fill val first, then test, then train
         if cur_val + c <= target_val:
-            val_ids.update(members)
         elif cur_test + c <= target_test:
-            test_ids.update(members)
         else:
             train_ids.update(members)
 
     return {"train": train_ids, "val": val_ids, "test": test_ids}
 
 # ─────────────────────────────────────────────────────────────────────────
 # Main
 # ─────────────────────────────────────────────────────────────────────────
 
 def main():
     parser = argparse.ArgumentParser(
         description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
     )
-    parser.add_argument(
-
-
-    parser.add_argument(
-
     parser.add_argument("--out_dir", required=True, help="Output directory")
     parser.add_argument("--seed", type=int, default=42)
 
     # NEW: MMseqs options & split fractions
     parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
-    parser.add_argument(
-
-
     parser.add_argument("--val_frac", type=float, default=0.10)
     parser.add_argument("--test_frac", type=float, default=0.10)
-    parser.add_argument(
     args = parser.parse_args()
 
     random.seed(args.seed)
@@ -260,12 +315,24 @@ def main():
     print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
 
     # Split embeddings into per-item files (unchanged)
-    print(
-
-
-
-
-
     # Build gene-symbol normalized map
     tf_symbol_map = build_tf_symbol_map(tf_map)
@@ -279,15 +346,28 @@ def main():
     print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
     print(f"[i] Intersection count: {len(intersect_tf)}")
     if len(intersect_tf) == 0:
-        print(
-
-
         sys.exit(1)
 
     dna_ids_final = set(df["dna_id"].unique())
     available_dna_ids = set(dna_map.keys())
     intersect_dna = dna_ids_final & available_dna_ids
-    print(
     if len(intersect_dna) == 0:
         print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
         sys.exit(1)
@@ -295,7 +375,9 @@ def main():
     # ── NEW: MMseqs clustering on DNA sequences ───────────────────────────
     fasta_path = out_dir / "dna_unique.fasta"
     write_dna_fasta(df, fasta_path)
-    print(
 
     tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
     cluster_prefix = out_dir / "mmseqs_dna_clusters"
@@ -310,7 +392,7 @@ def main():
     )
 
     # Parse clusters
-    member_to_rep = parse_mmseqs_clusters(clusters_tsv)
     # Build rep -> members list
     rep_to_members = defaultdict(list)
     for member, rep in member_to_rep.items():
@@ -331,10 +413,9 @@ def main():
     print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")
 
     # Assign entire clusters to splits
-    splits = assign_clusters_to_splits(
-
-
-        seed=args.seed)
     for k in ["train", "val", "test"]:
         print(f"[i] {k}: {len(splits[k])} dna_ids")
@@ -354,14 +435,22 @@ def main():
 
         # decide split by dna_id cluster assignment
         if dnaid in splits["train"]:
-            positives_by_split["train"].append(
         elif dnaid in splits["val"]:
-            positives_by_split["val"].append(
         elif dnaid in splits["test"]:
-            positives_by_split["test"].append(
             pos_count += 1
 
-    print(
     for k in ["train", "val", "test"]:
         print(f"[i] positives[{k}] = {len(positives_by_split[k])}")
@@ -373,11 +462,14 @@ def main():
     for split in ["train", "val", "test"]:
         out_tsv = out_dir / f"pair_list_{split}.tsv"
         with open(out_tsv, "w") as f:
-            for binder_path, glm_path, label in positives_by_split[
             f.write(f"{binder_path}\t{glm_path}\t{label}\n")
         print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")
 
     print("✅ Done. Cluster-aware splits ready.")
 
 
 if __name__ == "__main__":
     main()
|
|
|
| 12 |
# Original helpers (kept; some lightly edited/commented where needed)
|
| 13 |
# ─────────────────────────────────────────────────────────────────────────
|
| 14 |
|
| 15 |
+
|
| 16 |
def read_ids_file(p):
|
| 17 |
p = Path(p)
|
| 18 |
if not p.exists():
|
| 19 |
raise FileNotFoundError(f"IDs file not found: {p}")
|
| 20 |
return [line.strip() for line in p.open() if line.strip()]
|
| 21 |
|
| 22 |
+
|
| 23 |
def split_embeddings(emb_path, ids_path, out_dir, prefix):
|
| 24 |
out_dir = Path(out_dir)
|
| 25 |
out_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 40 |
|
| 41 |
ids = read_ids_file(ids_path)
|
| 42 |
if len(ids) != emb.shape[0]:
|
| 43 |
+
print(
|
| 44 |
+
f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}",
|
| 45 |
+
file=sys.stderr,
|
| 46 |
+
)
|
| 47 |
|
| 48 |
mapping = {}
|
| 49 |
for i, ident in enumerate(ids):
|
| 50 |
if i >= emb.shape[0]:
|
| 51 |
+
print(
|
| 52 |
+
f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr
|
| 53 |
+
)
|
| 54 |
continue
|
| 55 |
arr = emb[i]
|
| 56 |
out_file = out_dir / f"{prefix}_{ident}.npy"
|
|
|
|
| 58 |
mapping[ident] = str(out_file)
|
| 59 |
return mapping
|
| 60 |
|
| 61 |
+
|
| 62 |
def extract_symbol_from_tf_id(full_id: str) -> str:
|
| 63 |
"""
|
| 64 |
Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
|
|
|
|
| 75 |
symbol = genepart.split("_")[0]
|
| 76 |
return symbol.upper()
|
| 77 |
|
| 78 |
+
|
| 79 |
def build_tf_symbol_map(tf_map):
|
| 80 |
"""
|
| 81 |
Build mapping gene_symbol -> list of embedding paths.
|
|
|
|
| 86 |
symbol_map.setdefault(symbol, []).append(path)
|
| 87 |
return symbol_map
|
| 88 |
|
| 89 |
+
|
| 90 |
def tf_key_from_path(path: str) -> str:
|
| 91 |
"""
|
| 92 |
Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
|
|
|
|
| 99 |
rest = stem
|
| 100 |
return extract_symbol_from_tf_id(rest)
|
| 101 |
|
| 102 |
+
|
| 103 |
def dna_key_from_path(path: str) -> str:
|
| 104 |
"""
|
| 105 |
Given .../dna_peak42.npy -> 'peak42'
|
|
|
|
| 111 |
rest = stem
|
| 112 |
return rest
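# Illustrative sketch (not part of the diff): expected behaviour of the ID helpers
# above, assuming the tf_*/dna_* file naming produced by split_embeddings(); the
# paths shown are hypothetical.
#
#   extract_symbol_from_tf_id("sp|O15062|ZBTB5_HUMAN")          # -> "ZBTB5"
#   extract_symbol_from_tf_id("ZBTB5_HUMAN")                    # -> "ZBTB5"
#   tf_key_from_path("tf_single/tf_sp|O15062|ZBTB5_HUMAN.npy")  # -> "ZBTB5"
#   dna_key_from_path("dna_single/dna_peak42.npy")              # -> "peak42"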
|
| 113 |
|
| 114 |
+
|
| 115 |
# ─────────────────────────────────────────────────────────────────────────
|
| 116 |
# New helpers for MMseqs clustering & cluster-level splitting
|
| 117 |
# ─────────────────────────────────────────────────────────────────────────
|
| 118 |
|
| 119 |
+
|
| 120 |
def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
|
| 121 |
"""
|
| 122 |
Write unique DNA sequences to FASTA using dna_id as header.
|
|
|
|
| 129 |
seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
|
| 130 |
f.write(f">{did}\n{seq}\n")
|
| 131 |
|
| 132 |
+
|
| 133 |
def run_mmseqs_easy_cluster(
|
| 134 |
mmseqs_bin: str,
|
| 135 |
fasta: Path,
|
|
|
|
| 147 |
out_prefix.parent.mkdir(parents=True, exist_ok=True)
|
| 148 |
|
| 149 |
cmd = [
|
| 150 |
+
mmseqs_bin,
|
| 151 |
+
"easy-cluster",
|
| 152 |
+
str(fasta),
|
| 153 |
+
str(out_prefix),
|
| 154 |
+
str(tmp_dir),
|
| 155 |
+
"--min-seq-id",
|
| 156 |
+
str(min_seq_id),
|
| 157 |
+
"-c",
|
| 158 |
+
str(coverage),
|
| 159 |
+
"--cov-mode",
|
| 160 |
+
str(cov_mode),
|
| 161 |
# You can add performance flags here if needed, e.g.:
|
| 162 |
# "--threads", "8"
|
| 163 |
]
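# With the defaults exposed in main() below (--min_seq_id 0.9, --cov 0.8,
# --cov_mode 1), the argv assembled above is equivalent to running, for example:
#
#   mmseqs easy-cluster dna_unique.fasta mmseqs_dna_clusters mmseqs_tmp \
#       --min-seq-id 0.9 -c 0.8 --cov-mode 1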
|
|
|
|
| 177 |
cl_db = Path(str(out_prefix) + "_cluster")
|
| 178 |
out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
|
| 179 |
if in_db.exists() and cl_db.exists():
|
| 180 |
+
cmd2 = [
|
| 181 |
+
mmseqs_bin,
|
| 182 |
+
"createtsv",
|
| 183 |
+
str(in_db),
|
| 184 |
+
str(in_db),
|
| 185 |
+
str(cl_db),
|
| 186 |
+
str(out_tsv),
|
| 187 |
+
]
|
| 188 |
print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
|
| 189 |
subprocess.run(cmd2, check=True)
|
| 190 |
if out_tsv.exists():
|
| 191 |
return out_tsv
|
| 192 |
|
| 193 |
+
raise FileNotFoundError(
|
| 194 |
+
"Could not locate clusters TSV from mmseqs. "
|
| 195 |
+
"Expected {default_tsv} or createtsv fallback."
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
|
| 199 |
def parse_mmseqs_clusters(tsv_path: Path) -> dict:
|
| 200 |
"""
|
|
|
|
| 204 |
with open(tsv_path) as f:
|
| 205 |
for line in f:
|
| 206 |
parts = line.rstrip("\n").split("\t")
|
| 207 |
+
if len(parts) < 2:
|
| 208 |
continue
|
| 209 |
rep, member = parts[0], parts[1]
|
| 210 |
mapping[member] = rep
|
|
|
|
| 213 |
mapping[rep] = rep
|
| 214 |
return mapping
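# Illustrative sketch (not part of the diff): easy-cluster writes one
# "representative<TAB>member" row per sequence, so a TSV such as
#
#   peak1   peak1
#   peak1   peak7
#   peak3   peak3
#
# is parsed into {"peak1": "peak1", "peak7": "peak1", "peak3": "peak3"}
# (member dna_id -> cluster representative); the peak ids are hypothetical.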
|
| 215 |
|
| 216 |
+
|
| 217 |
+
def assign_clusters_to_splits(
|
| 218 |
+
cluster_rep_to_members: dict, val_frac: float, test_frac: float, seed: int = 42
|
| 219 |
+
):
|
| 220 |
"""
|
| 221 |
cluster_rep_to_members: dict[rep] = [members...]
|
| 222 |
Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
|
|
|
|
| 238 |
c = len(members)
|
| 239 |
# Fill val first, then test, then train
|
| 240 |
if cur_val + c <= target_val:
|
| 241 |
+
val_ids.update(members)
|
| 242 |
+
cur_val += c
|
| 243 |
elif cur_test + c <= target_test:
|
| 244 |
+
test_ids.update(members)
|
| 245 |
+
cur_test += c
|
| 246 |
else:
|
| 247 |
train_ids.update(members)
|
| 248 |
|
| 249 |
return {"train": train_ids, "val": val_ids, "test": test_ids}
|
| 250 |
|
| 251 |
+
|
| 252 |
# ─────────────────────────────────────────────────────────────────────────
|
| 253 |
# Main
|
| 254 |
# ─────────────────────────────────────────────────────────────────────────
|
| 255 |
|
| 256 |
+
|
| 257 |
def main():
|
| 258 |
parser = argparse.ArgumentParser(
|
| 259 |
description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
|
| 260 |
)
|
| 261 |
+
parser.add_argument(
|
| 262 |
+
"--final_csv", required=True, help="final.csv with TF_id and dna_sequence"
|
| 263 |
+
)
|
| 264 |
+
parser.add_argument(
|
| 265 |
+
"--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)"
|
| 266 |
+
)
|
| 267 |
+
parser.add_argument(
|
| 268 |
+
"--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)"
|
| 269 |
+
)
|
| 270 |
+
parser.add_argument(
|
| 271 |
+
"--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)"
|
| 272 |
+
)
|
| 273 |
+
parser.add_argument(
|
| 274 |
+
"--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)"
|
| 275 |
+
)
|
| 276 |
parser.add_argument("--out_dir", required=True, help="Output directory")
|
| 277 |
parser.add_argument("--seed", type=int, default=42)
|
| 278 |
|
| 279 |
# NEW: MMseqs options & split fractions
|
| 280 |
parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
|
| 281 |
+
parser.add_argument(
|
| 282 |
+
"--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id"
|
| 283 |
+
)
|
| 284 |
+
parser.add_argument(
|
| 285 |
+
"--cov", type=float, default=0.8, help="MMseqs -c coverage fraction"
|
| 286 |
+
)
|
| 287 |
+
parser.add_argument(
|
| 288 |
+
"--cov_mode",
|
| 289 |
+
type=int,
|
| 290 |
+
default=1,
|
| 291 |
+
help="MMseqs --cov-mode (1 = coverage of target)",
|
| 292 |
+
)
|
| 293 |
parser.add_argument("--val_frac", type=float, default=0.10)
|
| 294 |
parser.add_argument("--test_frac", type=float, default=0.10)
|
| 295 |
+
parser.add_argument(
|
| 296 |
+
"--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)"
|
| 297 |
+
)
|
| 298 |
args = parser.parse_args()
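# Illustrative invocation (not part of the diff); the script and file names are
# hypothetical, only the flags match the parser defined above:
#
#   python pair_list_builder.py \
#       --final_csv final.csv \
#       --dna_embed_npz embeddings/peaks_caduceus.npy --dna_ids embeddings/peaks.ids \
#       --tf_embed_npy embeddings/tf_esm-dbp.npy --tf_ids embeddings/tf.ids \
#       --out_dir pairs_out \
#       --min_seq_id 0.9 --cov 0.8 --cov_mode 1 \
#       --val_frac 0.10 --test_frac 0.10 --seed 42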
|
| 299 |
|
| 300 |
random.seed(args.seed)
|
|
|
|
| 315 |
print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
|
| 316 |
|
| 317 |
# Split embeddings into per-item files (unchanged)
|
| 318 |
+
print(
|
| 319 |
+
f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}"
|
| 320 |
+
)
|
| 321 |
+
dna_map = split_embeddings(
|
| 322 |
+
args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna"
|
| 323 |
+
)
|
| 324 |
+
print(
|
| 325 |
+
f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})"
|
| 326 |
+
)
|
| 327 |
+
print(
|
| 328 |
+
f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}"
|
| 329 |
+
)
|
| 330 |
+
tf_map = split_embeddings(
|
| 331 |
+
args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf"
|
| 332 |
+
)
|
| 333 |
+
print(
|
| 334 |
+
f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})"
|
| 335 |
+
)
|
| 336 |
|
| 337 |
# Build gene-symbol normalized map
|
| 338 |
tf_symbol_map = build_tf_symbol_map(tf_map)
|
|
|
|
| 346 |
print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
|
| 347 |
print(f"[i] Intersection count: {len(intersect_tf)}")
|
| 348 |
if len(intersect_tf) == 0:
|
| 349 |
+
print(
|
| 350 |
+
"[ERROR] No overlap between normalized TF_id and TF embedding symbols.",
|
| 351 |
+
file=sys.stderr,
|
| 352 |
+
)
|
| 353 |
+
print(
|
| 354 |
+
"Sample normalized TFs from final.csv:",
|
| 355 |
+
sorted(list(norm_tf_in_final))[:30],
|
| 356 |
+
file=sys.stderr,
|
| 357 |
+
)
|
| 358 |
+
print(
|
| 359 |
+
"Sample available TF symbols:",
|
| 360 |
+
sorted(list(available_tf_symbols))[:30],
|
| 361 |
+
file=sys.stderr,
|
| 362 |
+
)
|
| 363 |
sys.exit(1)
|
| 364 |
|
| 365 |
dna_ids_final = set(df["dna_id"].unique())
|
| 366 |
available_dna_ids = set(dna_map.keys())
|
| 367 |
intersect_dna = dna_ids_final & available_dna_ids
|
| 368 |
+
print(
|
| 369 |
+
f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}"
|
| 370 |
+
)
|
| 371 |
if len(intersect_dna) == 0:
|
| 372 |
print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
|
| 373 |
sys.exit(1)
|
|
|
|
| 375 |
# ── NEW: MMseqs clustering on DNA sequences ───────────────────────────
|
| 376 |
fasta_path = out_dir / "dna_unique.fasta"
|
| 377 |
write_dna_fasta(df, fasta_path)
|
| 378 |
+
print(
|
| 379 |
+
f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}"
|
| 380 |
+
)
|
| 381 |
|
| 382 |
tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
|
| 383 |
cluster_prefix = out_dir / "mmseqs_dna_clusters"
|
|
|
|
| 392 |
)
|
| 393 |
|
| 394 |
# Parse clusters
|
| 395 |
+
member_to_rep = parse_mmseqs_clusters(clusters_tsv) # dna_id -> rep_id
|
| 396 |
# Build rep -> members list
|
| 397 |
rep_to_members = defaultdict(list)
|
| 398 |
for member, rep in member_to_rep.items():
|
|
|
|
| 413 |
print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")
|
| 414 |
|
| 415 |
# Assign entire clusters to splits
|
| 416 |
+
splits = assign_clusters_to_splits(
|
| 417 |
+
rep_to_members, val_frac=args.val_frac, test_frac=args.test_frac, seed=args.seed
|
| 418 |
+
)
|
|
|
|
| 419 |
for k in ["train", "val", "test"]:
|
| 420 |
print(f"[i] {k}: {len(splits[k])} dna_ids")
|
| 421 |
|
|
|
|
| 435 |
|
| 436 |
# decide split by dna_id cluster assignment
|
| 437 |
if dnaid in splits["train"]:
|
| 438 |
+
positives_by_split["train"].append(
|
| 439 |
+
(tf_embedding_path, dnaid_to_path[dnaid], 1)
|
| 440 |
+
)
|
| 441 |
elif dnaid in splits["val"]:
|
| 442 |
+
positives_by_split["val"].append(
|
| 443 |
+
(tf_embedding_path, dnaid_to_path[dnaid], 1)
|
| 444 |
+
)
|
| 445 |
elif dnaid in splits["test"]:
|
| 446 |
+
positives_by_split["test"].append(
|
| 447 |
+
(tf_embedding_path, dnaid_to_path[dnaid], 1)
|
| 448 |
+
)
|
| 449 |
pos_count += 1
|
| 450 |
|
| 451 |
+
print(
|
| 452 |
+
f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})"
|
| 453 |
+
)
|
| 454 |
for k in ["train", "val", "test"]:
|
| 455 |
print(f"[i] positives[{k}] = {len(positives_by_split[k])}")
|
| 456 |
|
|
|
|
| 462 |
for split in ["train", "val", "test"]:
|
| 463 |
out_tsv = out_dir / f"pair_list_{split}.tsv"
|
| 464 |
with open(out_tsv, "w") as f:
|
| 465 |
+
for binder_path, glm_path, label in positives_by_split[
|
| 466 |
+
split
|
| 467 |
+
]:  # plus negatives, if you add them later
|
| 468 |
f.write(f"{binder_path}\t{glm_path}\t{label}\n")
|
| 469 |
print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")
|
| 470 |
|
| 471 |
print("✅ Done. Cluster-aware splits ready.")
|
| 472 |
|
| 473 |
+
|
| 474 |
if __name__ == "__main__":
|
| 475 |
main()
|
dpacman/classifier/model_tmp/compress_embeddings.py
CHANGED
|
@@ -7,6 +7,7 @@ import numpy as np
|
|
| 7 |
import torch
|
| 8 |
from torch import nn
|
| 9 |
|
|
|
|
| 10 |
class EmbeddingCompressor(nn.Module):
|
| 11 |
def __init__(self, input_dim: int = 1280, output_dim: int = 256):
|
| 12 |
super().__init__()
|
|
@@ -19,26 +20,33 @@ class EmbeddingCompressor(nn.Module):
|
|
| 19 |
"""
|
| 20 |
if x.dim() == 2:
|
| 21 |
# single example: mean over tokens
|
| 22 |
-
x = x.mean(dim=0, keepdim=True)
|
| 23 |
else:
|
| 24 |
# batch: mean over tokens
|
| 25 |
-
x = x.mean(dim=1)
|
| 26 |
-
return self.fc(x)
|
|
|
|
| 27 |
|
| 28 |
def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
|
| 29 |
-
arr = np.load(in_path)
|
| 30 |
tensor = torch.from_numpy(arr).float()
|
| 31 |
with torch.no_grad():
|
| 32 |
-
compressed = model(tensor)
|
| 33 |
out = compressed.cpu().numpy()
|
| 34 |
np.save(out_path, out)
|
| 35 |
print(f"Saved {out_path}")
|
| 36 |
|
|
|
|
| 37 |
if __name__ == "__main__":
|
| 38 |
import argparse
|
|
|
|
| 39 |
parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256d")
|
| 40 |
-
parser.add_argument(
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
parser.add_argument("--output_dir", type=str, required=True)
|
| 43 |
parser.add_argument("--esm_dim", type=int, default=1280)
|
| 44 |
parser.add_argument("--out_dim", type=int, default=256)
|
|
|
|
| 7 |
import torch
|
| 8 |
from torch import nn
|
| 9 |
|
| 10 |
+
|
| 11 |
class EmbeddingCompressor(nn.Module):
|
| 12 |
def __init__(self, input_dim: int = 1280, output_dim: int = 256):
|
| 13 |
super().__init__()
|
|
|
|
| 20 |
"""
|
| 21 |
if x.dim() == 2:
|
| 22 |
# single example: mean over tokens
|
| 23 |
+
x = x.mean(dim=0, keepdim=True) # → (1, input_dim)
|
| 24 |
else:
|
| 25 |
# batch: mean over tokens
|
| 26 |
+
x = x.mean(dim=1) # → (batch, input_dim)
|
| 27 |
+
return self.fc(x) # → (batch, output_dim)
|
| 28 |
+
|
| 29 |
|
| 30 |
def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
|
| 31 |
+
arr = np.load(in_path) # shape (L, D) or (batch, L, D)
|
| 32 |
tensor = torch.from_numpy(arr).float()
|
| 33 |
with torch.no_grad():
|
| 34 |
+
compressed = model(tensor) # → (batch, 256)
|
| 35 |
out = compressed.cpu().numpy()
|
| 36 |
np.save(out_path, out)
|
| 37 |
print(f"Saved {out_path}")
|
| 38 |
|
| 39 |
+
|
| 40 |
if __name__ == "__main__":
|
| 41 |
import argparse
|
| 42 |
+
|
| 43 |
parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256d")
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"--input_glob",
|
| 46 |
+
type=str,
|
| 47 |
+
required=True,
|
| 48 |
+
help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)",
|
| 49 |
+
)
|
| 50 |
parser.add_argument("--output_dir", type=str, required=True)
|
| 51 |
parser.add_argument("--esm_dim", type=int, default=1280)
|
| 52 |
parser.add_argument("--out_dim", type=int, default=256)
|
dpacman/classifier/model_tmp/compute_embeddings.py
CHANGED
|
@@ -14,6 +14,7 @@ Usage example (DNA + protein in one go):
|
|
| 14 |
--out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
|
| 15 |
--device cuda
|
| 16 |
"""
|
|
|
|
| 17 |
import os
|
| 18 |
import re
|
| 19 |
import argparse
|
|
@@ -28,6 +29,7 @@ import time
|
|
| 28 |
|
| 29 |
# ---- model wrappers ----
|
| 30 |
|
|
|
|
| 31 |
class CaduceusEmbedder:
|
| 32 |
def __init__(self, device, chunk_size=131_072, overlap=0):
|
| 33 |
"""
|
|
@@ -39,12 +41,14 @@ class CaduceusEmbedder:
|
|
| 39 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 40 |
model_name, trust_remote_code=True
|
| 41 |
)
|
| 42 |
-
self.model =
|
| 43 |
-
model_name, trust_remote_code=True
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
| 46 |
self.chunk_size = chunk_size
|
| 47 |
-
self.step
|
| 48 |
|
| 49 |
def embed(self, seqs):
|
| 50 |
"""
|
|
@@ -73,14 +77,13 @@ class CaduceusEmbedder:
|
|
| 73 |
return_tensors="pt",
|
| 74 |
padding=False,
|
| 75 |
truncation=True,
|
| 76 |
-
max_length=self.chunk_size
|
| 77 |
).to(self.device)
|
| 78 |
with torch.no_grad():
|
| 79 |
out = self.model(**toks).last_hidden_state # (1, L, D)
|
| 80 |
-
outputs.append(out.cpu().numpy()[0])
|
| 81 |
return outputs # list of variable-length (L_i, D) arrays
|
| 82 |
|
| 83 |
-
|
| 84 |
def benchmark(self, lengths=None):
|
| 85 |
"""
|
| 86 |
Time embedding on single-sequence of various lengths.
|
|
@@ -101,10 +104,17 @@ class CaduceusEmbedder:
|
|
| 101 |
t1 = time.perf_counter()
|
| 102 |
print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms")
|
| 103 |
|
|
|
|
| 104 |
class SegmentNTEmbedder:
|
| 105 |
def __init__(self, device):
|
| 106 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 107 |
-
|
|
|
| 108 |
self.device = device
|
| 109 |
|
| 110 |
def _adjust_length(self, input_ids):
|
|
@@ -113,7 +123,12 @@ class SegmentNTEmbedder:
|
|
| 113 |
remainder = (excl) % 4
|
| 114 |
if remainder != 0:
|
| 115 |
pad_needed = 4 - remainder
|
| 116 |
-
pad_tensor = torch.full(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
input_ids = torch.cat([input_ids, pad_tensor], dim=1)
|
| 118 |
return input_ids
|
| 119 |
|
|
@@ -135,7 +150,7 @@ class SegmentNTEmbedder:
|
|
| 135 |
attention_mask = input_ids != self.tokenizer.pad_token_id
|
| 136 |
|
| 137 |
input_ids = self._adjust_length(input_ids)
|
| 138 |
-
attention_mask =
|
| 139 |
|
| 140 |
with torch.no_grad():
|
| 141 |
outs = self.model(
|
|
@@ -161,19 +176,26 @@ class SegmentNTEmbedder:
|
|
| 161 |
|
| 162 |
class DNABertEmbedder:
|
| 163 |
def __init__(self, device):
|
| 164 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
def embed(self, seqs):
|
| 169 |
embs = []
|
| 170 |
for s in seqs:
|
| 171 |
-
tokens = self.tokenizer(s, return_tensors="pt", padding=True)[
|
|
|
|
|
|
|
| 172 |
with torch.no_grad():
|
| 173 |
out = self.model(tokens).last_hidden_state.mean(1)
|
| 174 |
embs.append(out.cpu().numpy())
|
| 175 |
return np.vstack(embs)
|
| 176 |
|
|
|
|
| 177 |
class NucleotideTransformerEmbedder:
|
| 178 |
def __init__(self, device):
|
| 179 |
# HF “feature-extraction” returns a list of (L, D) arrays for each input
|
|
@@ -181,7 +203,9 @@ class NucleotideTransformerEmbedder:
|
|
| 181 |
self.pipe = pipeline(
|
| 182 |
"feature-extraction",
|
| 183 |
model="InstaDeepAI/nucleotide-transformer-500m-1000g",
|
| 184 |
-
device=
|
|
|
|
|
|
|
| 185 |
)
|
| 186 |
|
| 187 |
def embed(self, seqs):
|
|
@@ -191,8 +215,9 @@ class NucleotideTransformerEmbedder:
|
|
| 191 |
"""
|
| 192 |
all_embeddings = self.pipe(seqs, truncation=True, padding=True)
|
| 193 |
# all_embeddings is a List of shape (L, D) arrays
|
| 194 |
-
pooled = [
|
| 195 |
-
return np.vstack(pooled)
|
|
|
|
| 196 |
|
| 197 |
# class ESMEmbedder:
|
| 198 |
# def __init__(self, device):
|
|
@@ -225,7 +250,9 @@ class ESMEmbedder:
|
|
| 225 |
self.batch_converter = self.alphabet.get_batch_converter()
|
| 226 |
self.model.to(device).eval()
|
| 227 |
# determine max length: esm2 models vary; use default 1024 for esm1b
|
| 228 |
-
self.max_len =
|
|
|
|
|
|
|
| 229 |
# for chunking: reserve 2 tokens if model uses BOS/EOS
|
| 230 |
self.chunk_size = self.max_len - 2
|
| 231 |
self.overlap = self.chunk_size // 4 # 25% overlap to smooth boundaries
|
|
@@ -280,7 +307,7 @@ class ESMEmbedder:
|
|
| 280 |
|
| 281 |
# class ESMDBPEmbedder:
|
| 282 |
# def __init__(self, device):
|
| 283 |
-
# base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 284 |
# model_path = (
|
| 285 |
# Path(__file__).resolve().parent.parent
|
| 286 |
# / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
|
|
@@ -310,12 +337,15 @@ class ESMEmbedder:
|
|
| 310 |
# # skip start/end tokens
|
| 311 |
# return reps[:, 1:-1].mean(1).cpu().numpy()
|
| 312 |
|
|
|
|
| 313 |
class ESMDBPEmbedder:
|
| 314 |
def __init__(self, device):
|
| 315 |
base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 316 |
model_path = (
|
| 317 |
Path(__file__).resolve().parent.parent
|
| 318 |
-
/ "pretrained"
|
|
|
|
|
|
|
| 319 |
)
|
| 320 |
checkpoint = torch.load(model_path, map_location="cpu")
|
| 321 |
clean_sd = {}
|
|
@@ -372,6 +402,7 @@ class ESMDBPEmbedder:
|
|
| 372 |
all_embeddings.append(seq_vec.cpu().numpy())
|
| 373 |
return np.vstack(all_embeddings)
|
| 374 |
|
|
|
|
| 375 |
class GPNEmbedder:
|
| 376 |
def __init__(self, device):
|
| 377 |
model_name = "songlab/gpn-msa-sapiens"
|
|
@@ -383,16 +414,14 @@ class GPNEmbedder:
|
|
| 383 |
|
| 384 |
def embed(self, seqs):
|
| 385 |
inputs = self.tokenizer(
|
| 386 |
-
seqs,
|
| 387 |
-
return_tensors="pt",
|
| 388 |
-
padding=True,
|
| 389 |
-
truncation=True
|
| 390 |
).to(self.device)
|
| 391 |
|
| 392 |
with torch.no_grad():
|
| 393 |
last_hidden = self.model(**inputs).last_hidden_state
|
| 394 |
return last_hidden.mean(dim=1).cpu().numpy()
|
| 395 |
|
|
|
|
| 396 |
class ProGenEmbedder:
|
| 397 |
def __init__(self, device):
|
| 398 |
model_name = "jinyuan22/ProGen2-base"
|
|
@@ -402,29 +431,36 @@ class ProGenEmbedder:
|
|
| 402 |
|
| 403 |
def embed(self, seqs):
|
| 404 |
inputs = self.tokenizer(
|
| 405 |
-
seqs,
|
| 406 |
-
return_tensors="pt",
|
| 407 |
-
padding=True,
|
| 408 |
-
truncation=True
|
| 409 |
).to(self.device)
|
| 410 |
with torch.no_grad():
|
| 411 |
last_hidden = self.model(**inputs).last_hidden_state
|
| 412 |
return last_hidden.mean(dim=1).cpu().numpy()
|
| 413 |
|
|
|
|
| 414 |
# ---- main pipeline ----
|
| 415 |
|
|
|
|
| 416 |
def get_embedder(name, device, for_dna=True):
|
| 417 |
name = name.lower()
|
| 418 |
if for_dna:
|
| 419 |
-
if name=="caduceus":
|
| 420 |
-
|
| 421 |
-
if name=="
|
| 422 |
-
|
| 423 |
-
if name=="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
else:
|
| 425 |
-
if name in ("esm",):
|
| 426 |
-
|
| 427 |
-
if name
|
|
|
|
|
|
|
|
|
|
| 428 |
raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
|
| 429 |
|
| 430 |
|
|
@@ -446,20 +482,28 @@ def pad_token_embeddings(list_of_arrays, pad_value=0.0):
|
|
| 446 |
mask[i, :L] = True
|
| 447 |
return padded, mask
|
| 448 |
|
|
|
|
| 449 |
def embed_and_save(seqs, ids, embedder, out_path):
|
| 450 |
embs = embedder.embed(seqs)
|
| 451 |
|
| 452 |
# Decide whether we got variable-length per-token outputs (list of (L, D))
|
| 453 |
-
is_variable_token =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
|
| 455 |
if is_variable_token:
|
| 456 |
# pad to (N, L_max, D) + mask
|
| 457 |
padded, mask = pad_token_embeddings(embs)
|
| 458 |
# Save both embeddings and mask together in an .npz for convenience
|
| 459 |
-
np.savez_compressed(
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
|
|
|
|
|
|
| 463 |
else:
|
| 464 |
# fixed shape output, e.g., pooled (N, D)
|
| 465 |
array = np.vstack(embs) if isinstance(embs, list) else embs
|
|
@@ -468,17 +512,31 @@ def embed_and_save(seqs, ids, embedder, out_path):
|
|
| 468 |
f.write("\n".join(ids))
|
| 469 |
|
| 470 |
|
| 471 |
-
if __name__=="__main__":
|
| 472 |
|
| 473 |
p = argparse.ArgumentParser()
|
| 474 |
-
p.add_argument(
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
p.add_argument(
|
| 480 |
-
|
| 481 |
-
|
|
|
| 482 |
args = p.parse_args()
|
| 483 |
|
| 484 |
os.makedirs(args.out_dir, exist_ok=True)
|
|
@@ -495,7 +553,10 @@ if __name__=="__main__":
|
|
| 495 |
for rec in SeqIO.parse(peak_fasta, "fasta"):
|
| 496 |
peak_ids.append(rec.id)
|
| 497 |
peak_seqs.append(str(rec.seq))
|
| 498 |
-
print(
|
|
|
|
|
|
|
|
|
|
| 499 |
dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
|
| 500 |
out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
|
| 501 |
embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
|
|
@@ -503,7 +564,9 @@ if __name__=="__main__":
|
|
| 503 |
# Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
|
| 504 |
genome_dir = Path(args.genome_json_dir)
|
| 505 |
chrom_seqs, chrom_ids = [], []
|
| 506 |
-
primary_pattern = re.compile(
|
|
|
|
|
|
|
| 507 |
for j in sorted(genome_dir.iterdir()):
|
| 508 |
if not primary_pattern.match(j.name):
|
| 509 |
continue
|
|
@@ -519,7 +582,9 @@ if __name__=="__main__":
|
|
| 519 |
if len(seq) > cutoff
|
| 520 |
]
|
| 521 |
if long_chroms:
|
| 522 |
-
print(
|
|
|
|
|
|
|
| 523 |
for chrom, L in long_chroms:
|
| 524 |
print(f" {chrom}: {L} bases")
|
| 525 |
else:
|
|
@@ -529,10 +594,11 @@ if __name__=="__main__":
|
|
| 529 |
out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
|
| 530 |
embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
|
| 531 |
else:
|
| 532 |
-
raise ValueError(
|
| 533 |
-
|
|
|
|
| 534 |
|
| 535 |
-
#Load TF sequences
|
| 536 |
tf_seqs, tf_ids = [], []
|
| 537 |
for record in SeqIO.parse(args.tf_fasta, "fasta"):
|
| 538 |
tf_ids.append(record.id)
|
|
@@ -543,4 +609,4 @@ if __name__=="__main__":
|
|
| 543 |
out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
|
| 544 |
embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
|
| 545 |
|
| 546 |
-
print("Done.")
|
|
|
|
| 14 |
--out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
|
| 15 |
--device cuda
|
| 16 |
"""
|
| 17 |
+
|
| 18 |
import os
|
| 19 |
import re
|
| 20 |
import argparse
|
|
|
|
| 29 |
|
| 30 |
# ---- model wrappers ----
|
| 31 |
|
| 32 |
+
|
| 33 |
class CaduceusEmbedder:
|
| 34 |
def __init__(self, device, chunk_size=131_072, overlap=0):
|
| 35 |
"""
|
|
|
|
| 41 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 42 |
model_name, trust_remote_code=True
|
| 43 |
)
|
| 44 |
+
self.model = (
|
| 45 |
+
AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
| 46 |
+
.to(device)
|
| 47 |
+
.eval()
|
| 48 |
+
)
|
| 49 |
+
self.device = device
|
| 50 |
self.chunk_size = chunk_size
|
| 51 |
+
self.step = chunk_size - overlap
|
| 52 |
|
| 53 |
def embed(self, seqs):
|
| 54 |
"""
|
|
|
|
| 77 |
return_tensors="pt",
|
| 78 |
padding=False,
|
| 79 |
truncation=True,
|
| 80 |
+
max_length=self.chunk_size,
|
| 81 |
).to(self.device)
|
| 82 |
with torch.no_grad():
|
| 83 |
out = self.model(**toks).last_hidden_state # (1, L, D)
|
| 84 |
+
outputs.append(out.cpu().numpy()[0]) # (L, D)
|
| 85 |
return outputs # list of variable-length (L_i, D) arrays
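# Illustrative sketch (not part of the diff): the chunking parameters set in
# __init__ imply long inputs are embedded window by window; the exact logic lives
# in the elided part of embed(), but with chunk_size and step it would look
# roughly like
#
#   windows = [seq[start:start + self.chunk_size]
#              for start in range(0, len(seq), self.step)]
#
# with each window tokenised and embedded separately as shown above.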
|
| 86 |
|
|
|
|
| 87 |
def benchmark(self, lengths=None):
|
| 88 |
"""
|
| 89 |
Time embedding on single-sequence of various lengths.
|
|
|
|
| 104 |
t1 = time.perf_counter()
|
| 105 |
print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms")
|
| 106 |
|
| 107 |
+
|
| 108 |
class SegmentNTEmbedder:
|
| 109 |
def __init__(self, device):
|
| 110 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 111 |
+
"InstaDeepAI/segment_nt", trust_remote_code=True
|
| 112 |
+
)
|
| 113 |
+
self.model = (
|
| 114 |
+
AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
|
| 115 |
+
.to(device)
|
| 116 |
+
.eval()
|
| 117 |
+
)
|
| 118 |
self.device = device
|
| 119 |
|
| 120 |
def _adjust_length(self, input_ids):
|
|
|
|
| 123 |
remainder = (excl) % 4
|
| 124 |
if remainder != 0:
|
| 125 |
pad_needed = 4 - remainder
|
| 126 |
+
pad_tensor = torch.full(
|
| 127 |
+
(bs, pad_needed),
|
| 128 |
+
self.tokenizer.pad_token_id,
|
| 129 |
+
dtype=input_ids.dtype,
|
| 130 |
+
device=input_ids.device,
|
| 131 |
+
)
|
| 132 |
input_ids = torch.cat([input_ids, pad_tensor], dim=1)
|
| 133 |
return input_ids
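# Worked example of the padding rule above: the token length excluding the special
# token (excl) must be a multiple of 4. For excl = 501, remainder = 501 % 4 = 1 and
# pad_needed = 3, so three pad_token_id columns are appended before the attention
# mask is rebuilt in embed() below.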
|
| 134 |
|
|
|
|
| 150 |
attention_mask = input_ids != self.tokenizer.pad_token_id
|
| 151 |
|
| 152 |
input_ids = self._adjust_length(input_ids)
|
| 153 |
+
attention_mask = input_ids != self.tokenizer.pad_token_id
|
| 154 |
|
| 155 |
with torch.no_grad():
|
| 156 |
outs = self.model(
|
|
|
|
| 176 |
|
| 177 |
class DNABertEmbedder:
|
| 178 |
def __init__(self, device):
|
| 179 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 180 |
+
"zhihan1996/DNA_bert_6", trust_remote_code=True
|
| 181 |
+
)
|
| 182 |
+
self.model = AutoModel.from_pretrained(
|
| 183 |
+
"zhihan1996/DNA_bert_6", trust_remote_code=True
|
| 184 |
+
).to(device)
|
| 185 |
+
self.device = device
|
| 186 |
|
| 187 |
def embed(self, seqs):
|
| 188 |
embs = []
|
| 189 |
for s in seqs:
|
| 190 |
+
tokens = self.tokenizer(s, return_tensors="pt", padding=True)[
|
| 191 |
+
"input_ids"
|
| 192 |
+
].to(self.device)
|
| 193 |
with torch.no_grad():
|
| 194 |
out = self.model(tokens).last_hidden_state.mean(1)
|
| 195 |
embs.append(out.cpu().numpy())
|
| 196 |
return np.vstack(embs)
|
| 197 |
|
| 198 |
+
|
| 199 |
class NucleotideTransformerEmbedder:
|
| 200 |
def __init__(self, device):
|
| 201 |
# HF “feature-extraction” returns a list of (L, D) arrays for each input
|
|
|
|
| 203 |
self.pipe = pipeline(
|
| 204 |
"feature-extraction",
|
| 205 |
model="InstaDeepAI/nucleotide-transformer-500m-1000g",
|
| 206 |
+
device=(
|
| 207 |
+
-1 if device == "cpu" else 0
|
| 208 |
+
), # HF pipelines use device=-1 for CPU and device=0 for the first GPU
|
| 209 |
)
|
| 210 |
|
| 211 |
def embed(self, seqs):
|
|
|
|
| 215 |
"""
|
| 216 |
all_embeddings = self.pipe(seqs, truncation=True, padding=True)
|
| 217 |
# all_embeddings is a List of shape (L, D) arrays
|
| 218 |
+
pooled = [np.mean(x, axis=0) for x in all_embeddings]
|
| 219 |
+
return np.vstack(pooled)
|
| 220 |
+
|
| 221 |
|
| 222 |
# class ESMEmbedder:
|
| 223 |
# def __init__(self, device):
|
|
|
|
| 250 |
self.batch_converter = self.alphabet.get_batch_converter()
|
| 251 |
self.model.to(device).eval()
|
| 252 |
# determine max length: esm2 models vary; use default 1024 for esm1b
|
| 253 |
+
self.max_len = (
|
| 254 |
+
4096 if self.is_esm2 else 1024
|
| 255 |
+
) # adjust if your ESM-2 variant has an explicit limit
|
| 256 |
# for chunking: reserve 2 tokens if model uses BOS/EOS
|
| 257 |
self.chunk_size = self.max_len - 2
|
| 258 |
self.overlap = self.chunk_size // 4 # 25% overlap to smooth boundaries
|
|
|
|
| 307 |
|
| 308 |
# class ESMDBPEmbedder:
|
| 309 |
# def __init__(self, device):
|
| 310 |
+
# base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 311 |
# model_path = (
|
| 312 |
# Path(__file__).resolve().parent.parent
|
| 313 |
# / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
|
|
|
|
| 337 |
# # skip start/end tokens
|
| 338 |
# return reps[:, 1:-1].mean(1).cpu().numpy()
|
| 339 |
|
| 340 |
+
|
| 341 |
class ESMDBPEmbedder:
|
| 342 |
def __init__(self, device):
|
| 343 |
base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 344 |
model_path = (
|
| 345 |
Path(__file__).resolve().parent.parent
|
| 346 |
+
/ "pretrained"
|
| 347 |
+
/ "ESM-DBP"
|
| 348 |
+
/ "ESM-DBP.model"
|
| 349 |
)
|
| 350 |
checkpoint = torch.load(model_path, map_location="cpu")
|
| 351 |
clean_sd = {}
|
|
|
|
| 402 |
all_embeddings.append(seq_vec.cpu().numpy())
|
| 403 |
return np.vstack(all_embeddings)
|
| 404 |
|
| 405 |
+
|
| 406 |
class GPNEmbedder:
|
| 407 |
def __init__(self, device):
|
| 408 |
model_name = "songlab/gpn-msa-sapiens"
|
|
|
|
| 414 |
|
| 415 |
def embed(self, seqs):
|
| 416 |
inputs = self.tokenizer(
|
| 417 |
+
seqs, return_tensors="pt", padding=True, truncation=True
|
|
|
|
|
|
|
|
|
|
| 418 |
).to(self.device)
|
| 419 |
|
| 420 |
with torch.no_grad():
|
| 421 |
last_hidden = self.model(**inputs).last_hidden_state
|
| 422 |
return last_hidden.mean(dim=1).cpu().numpy()
|
| 423 |
|
| 424 |
+
|
| 425 |
class ProGenEmbedder:
|
| 426 |
def __init__(self, device):
|
| 427 |
model_name = "jinyuan22/ProGen2-base"
|
|
|
|
| 431 |
|
| 432 |
def embed(self, seqs):
|
| 433 |
inputs = self.tokenizer(
|
| 434 |
+
seqs, return_tensors="pt", padding=True, truncation=True
|
|
|
|
|
|
|
|
|
|
| 435 |
).to(self.device)
|
| 436 |
with torch.no_grad():
|
| 437 |
last_hidden = self.model(**inputs).last_hidden_state
|
| 438 |
return last_hidden.mean(dim=1).cpu().numpy()
|
| 439 |
|
| 440 |
+
|
| 441 |
# ---- main pipeline ----
|
| 442 |
|
| 443 |
+
|
| 444 |
def get_embedder(name, device, for_dna=True):
|
| 445 |
name = name.lower()
|
| 446 |
if for_dna:
|
| 447 |
+
if name == "caduceus":
|
| 448 |
+
return CaduceusEmbedder(device)
|
| 449 |
+
if name == "dnabert":
|
| 450 |
+
return DNABertEmbedder(device)
|
| 451 |
+
if name == "nucleotide":
|
| 452 |
+
return NucleotideTransformerEmbedder(device)
|
| 453 |
+
if name == "gpn":
|
| 454 |
+
return GPNEmbedder(device)
|
| 455 |
+
if name == "segmentnt":
|
| 456 |
+
return SegmentNTEmbedder(device)
|
| 457 |
else:
|
| 458 |
+
if name in ("esm",):
|
| 459 |
+
return ESMEmbedder(device)
|
| 460 |
+
if name in ("esm-dbp", "esm_dbp"):
|
| 461 |
+
return ESMDBPEmbedder(device)
|
| 462 |
+
if name == "progen":
|
| 463 |
+
return ProGenEmbedder(device)
|
| 464 |
raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
|
| 465 |
|
| 466 |
|
|
|
|
| 482 |
mask[i, :L] = True
|
| 483 |
return padded, mask
|
| 484 |
|
| 485 |
+
|
| 486 |
def embed_and_save(seqs, ids, embedder, out_path):
|
| 487 |
embs = embedder.embed(seqs)
|
| 488 |
|
| 489 |
# Decide whether we got variable-length per-token outputs (list of (L, D))
|
| 490 |
+
is_variable_token = (
|
| 491 |
+
isinstance(embs, (list, tuple))
|
| 492 |
+
and len(embs) > 0
|
| 493 |
+
and hasattr(embs[0], "shape")
|
| 494 |
+
and embs[0].ndim == 2
|
| 495 |
+
)
|
| 496 |
|
| 497 |
if is_variable_token:
|
| 498 |
# pad to (N, L_max, D) + mask
|
| 499 |
padded, mask = pad_token_embeddings(embs)
|
| 500 |
# Save both embeddings and mask together in an .npz for convenience
|
| 501 |
+
np.savez_compressed(
|
| 502 |
+
out_path.with_suffix(".caduceus.npz"),
|
| 503 |
+
embeddings=padded,
|
| 504 |
+
mask=mask,
|
| 505 |
+
ids=np.array(ids, dtype=object),
|
| 506 |
+
)
|
| 507 |
else:
|
| 508 |
# fixed shape output, e.g., pooled (N, D)
|
| 509 |
array = np.vstack(embs) if isinstance(embs, list) else embs
|
|
|
|
| 512 |
f.write("\n".join(ids))
|
| 513 |
|
| 514 |
|
| 515 |
+
if __name__ == "__main__":
|
| 516 |
|
| 517 |
p = argparse.ArgumentParser()
|
| 518 |
+
p.add_argument(
|
| 519 |
+
"--peak-fasta",
|
| 520 |
+
default="binding_peaks_unique.fa",
|
| 521 |
+
help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs",
|
| 522 |
+
)
|
| 523 |
+
p.add_argument(
|
| 524 |
+
"--genome-json-dir",
|
| 525 |
+
default=None,
|
| 526 |
+
help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes",
|
| 527 |
+
)
|
| 528 |
+
p.add_argument(
|
| 529 |
+
"--skip-dna",
|
| 530 |
+
action="store_true",
|
| 531 |
+
help="if set, skip the chromosome embedding step",
|
| 532 |
+
)  # useful when the DNA (gLM) embeddings already exist but the protein (pLM) embeddings still need to be computed
|
| 533 |
+
p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
|
| 534 |
+
p.add_argument("--chrom-model", default="caduceus")
|
| 535 |
+
p.add_argument("--tf-model", default="esm-dbp")
|
| 536 |
+
p.add_argument(
|
| 537 |
+
"--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings"
|
| 538 |
+
)
|
| 539 |
+
p.add_argument("--device", default="cpu")
|
| 540 |
args = p.parse_args()
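# Illustrative invocation (not part of the diff); paths are hypothetical and mirror
# the usage example in the module docstring:
#
#   python compute_embeddings.py \
#       --peak-fasta binding_peaks_unique.fa \
#       --tf-fasta tf_sequences.fa \
#       --chrom-model caduceus --tf-model esm-dbp \
#       --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
#       --device cuda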
|
| 541 |
|
| 542 |
os.makedirs(args.out_dir, exist_ok=True)
|
|
|
|
| 553 |
for rec in SeqIO.parse(peak_fasta, "fasta"):
|
| 554 |
peak_ids.append(rec.id)
|
| 555 |
peak_seqs.append(str(rec.seq))
|
| 556 |
+
print(
|
| 557 |
+
f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}",
|
| 558 |
+
flush=True,
|
| 559 |
+
)
|
| 560 |
dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
|
| 561 |
out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
|
| 562 |
embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
|
|
|
|
| 564 |
# Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
|
| 565 |
genome_dir = Path(args.genome_json_dir)
|
| 566 |
chrom_seqs, chrom_ids = [], []
|
| 567 |
+
primary_pattern = re.compile(
|
| 568 |
+
r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$"
|
| 569 |
+
)
|
| 570 |
for j in sorted(genome_dir.iterdir()):
|
| 571 |
if not primary_pattern.match(j.name):
|
| 572 |
continue
|
|
|
|
| 582 |
if len(seq) > cutoff
|
| 583 |
]
|
| 584 |
if long_chroms:
|
| 585 |
+
print(
|
| 586 |
+
"⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff)
|
| 587 |
+
)
|
| 588 |
for chrom, L in long_chroms:
|
| 589 |
print(f" {chrom}: {L} bases")
|
| 590 |
else:
|
|
|
|
| 594 |
out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
|
| 595 |
embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
|
| 596 |
else:
|
| 597 |
+
raise ValueError(
|
| 598 |
+
"No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs."
|
| 599 |
+
)
|
| 600 |
|
| 601 |
+
# Load TF sequences
|
| 602 |
tf_seqs, tf_ids = [], []
|
| 603 |
for record in SeqIO.parse(args.tf_fasta, "fasta"):
|
| 604 |
tf_ids.append(record.id)
|
|
|
|
| 609 |
out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
|
| 610 |
embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
|
| 611 |
|
| 612 |
+
print("Done.")
|