svincoff commited on
Commit
6a0f8aa
·
1 Parent(s): c237769

small changes

Browse files
dpacman/data_modules/pair.py CHANGED
@@ -13,12 +13,12 @@ from torch.nn.utils.rnn import pad_sequence
13
  from typing import List, Iterable, Sequence
14
  import sys
15
  import rootutils
 
16
  from dpacman.utils import pylogger
17
 
18
  root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
19
  logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
20
 
21
-
22
  class PreBatchedSampler(Sampler[List[int]]):
23
  """
24
  Yields precomputed batches of indices, e.g. [[3,7,9], [0,1,2], ...].
@@ -508,7 +508,7 @@ def main():
508
  parser.add_argument(
509
  "--dna_shelf_path",
510
  type=str,
511
- default="../data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf",
512
  )
513
  parser.add_argument("--batch_size", type=int, default=4)
514
  parser.add_argument("--num_workers", type=int, default=4)
@@ -537,12 +537,12 @@ def main():
537
  )
538
 
539
  # ---- Train ----
540
- dm.setup(stage="train")
541
  train_dl = dm.train_dataloader()
542
- _peek_batches(train_dl, n_batches=args.n_batches, tag="train")
543
 
544
  # ---- Val ----
545
- dm.setup(stage="val")
546
  val_dl = dm.val_dataloader()
547
  _peek_batches(val_dl, n_batches=1, tag="val") # usually enough to sanity-check
548
 
 
13
  from typing import List, Iterable, Sequence
14
  import sys
15
  import rootutils
16
+ import logging
17
  from dpacman.utils import pylogger
18
 
19
  root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
20
  logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
21
 
 
22
  class PreBatchedSampler(Sampler[List[int]]):
23
  """
24
  Yields precomputed batches of indices, e.g. [[3,7,9], [0,1,2], ...].
 
508
  parser.add_argument(
509
  "--dna_shelf_path",
510
  type=str,
511
+ default="../data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf",
512
  )
513
  parser.add_argument("--batch_size", type=int, default=4)
514
  parser.add_argument("--num_workers", type=int, default=4)
 
537
  )
538
 
539
  # ---- Train ----
540
+ dm.setup(stage="fit")
541
  train_dl = dm.train_dataloader()
542
+ _peek_batches(train_dl, n_batches=args.n_batches, tag="fit")
543
 
544
  # ---- Val ----
545
+ dm.setup(stage="validate")
546
  val_dl = dm.val_dataloader()
547
  _peek_batches(val_dl, n_batches=1, tag="val") # usually enough to sanity-check
548
 
dpacman/scripts/run_train.sh CHANGED
@@ -22,6 +22,7 @@ nohup python -u -m scripts.train \
22
  data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
23
  data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
24
  model.glm_input_dim=256 \
 
25
  model.compressed_dim=1029 \
26
  > "${run_dir}/run.log" 2>&1 &
27
 
 
22
  data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
23
  data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
24
  model.glm_input_dim=256 \
25
+ model.lr=1e-5 \
26
  model.compressed_dim=1029 \
27
  > "${run_dir}/run.log" 2>&1 &
28