DPACMAN / dpacman /data_tasks /fimo /post_fimo.py
svincoff's picture
training works
29899b4
#!/usr/bin/env python3
import os
from pathlib import Path
import multiprocessing as mp
import numpy as np
import pandas as pd
import json
import rootutils
import polars as pl
from omegaconf import DictConfig
from hydra.core.hydra_config import HydraConfig
import logging
from dpacman.data_tasks.fimo.pre_fimo import load_chrom_dna
import rootutils
from dpacman.utils import pylogger
# Locate the repository root (the directory holding a ".project-root" marker);
# pythonpath=True asks rootutils to also put it on sys.path. Paths below
# (config-relative inputs/outputs) are built relative to `root`.
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# Module-level logger; rank_zero_only=True presumably limits output to rank 0
# (see dpacman.utils.pylogger). Worker processes create their own per-chrom loggers.
logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
def normalize_array(
    arr: np.ndarray, max_chipseq_score: int = 1000, jaspar_boost: int = 100
) -> np.ndarray:
    """Divide scores by the largest value a position can receive.

    The highest per-position score is a capped ChIP-seq score plus the JASPAR
    boost, so their sum is used as the divisor (mapping that maximum to 1.0).
    """
    return arr / (max_chipseq_score + jaspar_boost)
def format_sig(sig_vals, decimals=4, atol=0.0, rtol=1e-5):
    """Format numbers as one comma-separated string at fixed precision.

    Values that would print as 0.000... or 1.000... at ``decimals`` places are
    written as the bare literals ``"0"`` and ``"1"`` to keep the string short;
    every other value uses ``%.{decimals}f``.

    Args:
        sig_vals: Scalar or array-like of numbers (any shape; flattened in C order).
        decimals: Number of decimal places for values that are not 0/1.
        atol, rtol: Tolerances forwarded to ``np.isclose`` for the 0/1 checks.

    Returns:
        The formatted values joined by commas.
    """
    # ravel() turns a scalar (0-d) or N-d input into a flat 1-D array. Without it
    # a scalar's .tolist() is a plain str and ",".join would split its characters.
    a = np.asarray(sig_vals, dtype=float).ravel()
    scale = 10.0**decimals
    thresh = 0.5 / scale  # e.g. 0.00005 for 4 dp: anything closer rounds to 0/1
    # Would display as 0.0000 or 1.0000 at the given precision?
    m0 = np.isclose(a, 0.0, atol=atol, rtol=rtol) | (np.abs(a) <= thresh)
    m1 = np.isclose(a, 1.0, atol=atol, rtol=rtol) | (np.abs(a - 1.0) <= thresh)
    out = np.char.mod(f"%.{decimals}f", a)
    out = np.where(m0, "0", out)
    out = np.where(m1 & ~m0, "1", out)  # don't overwrite any zeros
    return ",".join(out.tolist())
def _safe_process(task):
    """Run one chromosome task and report the outcome as a status tuple.

    Used as the pool target so that one failing chromosome is reported instead
    of tearing down the whole pool.

    Returns:
        ("ok", <path-to-output as str>) on success
        ("err", (chrom, repr(exception), formatted traceback)) on failure
    """
    import traceback

    chrom_label = task[0]
    try:
        # _process_one_chrom_folder must return a path (str/Path)
        payload = str(_process_one_chrom_folder(task))
    except Exception as exc:  # deliberately broad: report, don't crash the pool
        return ("err", (chrom_label, repr(exc), traceback.format_exc()))
    return ("ok", payload)
def discover_chrom_folders(fimo_out_dir: Path) -> list[str]:
    """Return the sorted names of ``chrom*`` entries in ``fimo_out_dir`` that
    contain a ``final.csv``."""
    found = []
    for entry in os.listdir(fimo_out_dir):
        if not entry.startswith("chrom"):
            continue
        if (fimo_out_dir / entry / "final.csv").exists():
            found.append(entry)
    found.sort()
    return found
def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict:
    """Build one training record for a (TR, genomic context) row.

    Args:
        row: Tuple ordered as (TR name, chrom, context start, context end,
            peak start, peak end, ChIP score, jaspar). Coordinates are used as
            Python slice offsets into ``dna``. ``jaspar`` is a comma-separated
            string of "start-end" hits, or a non-string/blank value (e.g. NaN)
            when there are none.
        dna: Chromosome sequence the coordinates index into.
        jaspar_boost: Amount added on top of the ChIP score at JASPAR hit positions.

    Returns:
        Dict with the context sequence, the part of it covered by the peak, the
        (capped) ChIP score, the number of JASPAR hits and a comma-separated
        per-position integer score string. The score is 0 by default, the ChIP
        score inside the peak, and ChIP score + boost at JASPAR hit positions.
    """
    (trname, chrom, cstart, cend, peak_s, peak_e, chipscore, jaspar) = row

    # Only a handful of ChIP scores exceed 1000; cap them so all rows share one scale.
    chipscore = 1000 if chipscore >= 1000 else chipscore

    seq = dna[cstart:cend]
    window = len(seq)
    scores = np.zeros(window)

    # ChIP peak, expressed relative to the context start and clipped to the window.
    rel_start = peak_s - cstart
    rel_end = peak_e - cstart
    peak_seq = ""
    overlaps = rel_start < window and rel_end > 0
    if overlaps:
        lo, hi = max(rel_start, 0), min(rel_end, window)
        scores[lo:hi] = chipscore
        peak_seq = seq[lo:hi]

    # JASPAR hits get chipscore + jaspar_boost. Every listed hit is counted, even
    # one lying fully outside the window (which paints nothing).
    has_hits = isinstance(jaspar, str) and jaspar.strip()
    hits = jaspar.split(",") if has_hits else []
    for hit in hits:
        hit_start, hit_end = hit.split("-")
        lo = max(int(hit_start) - cstart, 0)
        hi = min(int(hit_end) - cstart, window)
        if lo < hi:
            scores[lo:hi] = chipscore + jaspar_boost

    score_str = ",".join(str(int(v)) for v in scores.tolist())
    return {
        "chrom": chrom,
        "tr_name": trname,
        "dna_sequence": seq,
        "peak_sequence": peak_seq,
        "chipscore": chipscore,
        "total_jaspar_hits": len(hits),
        "scores": score_str,
    }
# Column order of the records built by _process_one_row; also the header of the
# empty part written when a chromosome has nothing to process.
_PROCESSED_COLUMNS = [
    "chrom",
    "tr_name",
    "dna_sequence",
    "peak_sequence",
    "chipscore",
    "total_jaspar_hits",
    "scores",
]

# final.csv column -> short name used while building rows.
_FINAL_CSV_RENAME = {
    "#chrom": "chrom",
    "contextStart": "cstart",
    "contextEnd": "cend",
    "ChIPStart": "peak_s",
    "ChIPEnd": "peak_e",
    "TR": "tr_name",
}


def _open_worker_log(chrom_folder: str, fallback_dir: Path):
    """Attach a fresh file handler for ``chrom_folder``; return (logger, handler)."""
    try:
        log_dir = Path(HydraConfig.get().run.dir) / "logs"
    except ValueError:
        # HydraConfig.get() raises ValueError when Hydra state is absent in this
        # process (e.g. workers started with the "spawn" method).
        log_dir = Path(fallback_dir) / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    wlogger = logging.getLogger(f"fimo_{chrom_folder}")
    wlogger.setLevel(logging.DEBUG)
    wlogger.propagate = False  # Don't double-log to root
    for old in list(wlogger.handlers):
        if isinstance(old, logging.FileHandler):
            wlogger.removeHandler(old)
            old.close()
    handler = logging.FileHandler(
        log_dir / f"fimo_{chrom_folder}.log", mode="w", encoding="utf-8"
    )
    handler.setFormatter(
        logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    )
    wlogger.addHandler(handler)
    return wlogger, handler


def _read_final_csv(final_csv: Path, keep_fimo_only: bool, wlogger) -> pd.DataFrame:
    """Load final.csv (empty frame if missing) with numeric coordinates and short names.

    With ``keep_fimo_only`` only rows with a non-null ``jaspar`` entry are kept.
    """
    if not final_csv.exists():
        return pd.DataFrame()
    usecols = [
        "TR",
        "#chrom",
        "contextStart",
        "contextEnd",
        "ChIPStart",
        "ChIPEnd",
        "chipscore",
        "jaspar",
    ]
    df = pd.read_csv(final_csv, usecols=usecols)
    if keep_fimo_only:
        wlogger.info(f"keep_fimo_only=True. Starting with {len(df)} rows.")
        df = df.loc[~df["jaspar"].isna()].reset_index(drop=True)
        wlogger.info(f"After keeping fimo hits only: {len(df)} rows remain.")
    if df.empty:
        return df
    df["#chrom"] = df["#chrom"].astype(str)
    for col in ("contextStart", "contextEnd", "ChIPStart", "ChIPEnd", "chipscore"):
        # errors="raise" fails loudly on non-numeric coordinates/scores.
        df[col] = pd.to_numeric(df[col], errors="raise", downcast="integer")
    return df.rename(columns=_FINAL_CSV_RENAME)


def _process_one_chrom_folder(task) -> Path:
    """Worker entry point: turn one chromosome's ``final.csv`` into record rows.

    Runs inside a worker process. Reads ``<fimo_out_dir>/<chrom_folder>/final.csv``,
    loads the chromosome DNA once, builds one record per row with
    ``_process_one_row`` and writes them to
    ``<output_parts_folder>/<chrom_folder>_processed.csv``.

    Args:
        task: Tuple ``(chrom_folder, fimo_out_dir_str, json_dir, jaspar_boost,
            output_parts_folder, keep_fimo_only)``.

    Returns:
        Path of the written CSV. When there is nothing to process (missing
        final.csv, or no rows left after filtering) a header-only CSV is written,
        so the caller, which requires a path, always gets one.
    """
    (
        chrom_folder,
        fimo_out_dir_str,
        json_dir,
        jaspar_boost,
        output_parts_folder,
        keep_fimo_only,
    ) = task
    output_parts_folder = Path(output_parts_folder)
    output_parts_folder.mkdir(parents=True, exist_ok=True)
    savepath = output_parts_folder / f"{chrom_folder}_processed.csv"

    # Logs go next to (not inside) the parts folder in the fallback case, so the
    # parts folder only ever contains *_processed.csv files.
    wlogger, handler = _open_worker_log(chrom_folder, output_parts_folder.parent)
    try:
        final_csv = Path(fimo_out_dir_str) / chrom_folder / "final.csv"
        df = _read_final_csv(final_csv, keep_fimo_only, wlogger)
        if df.empty:
            pd.DataFrame(columns=_PROCESSED_COLUMNS).to_csv(savepath, index=False)
            wlogger.warning(
                f"No rows to process for {chrom_folder}; wrote empty {savepath}"
            )
            return savepath

        chrom = df["chrom"].iloc[0]
        dna_cache = {}
        # just capitalize it for training
        dna = load_chrom_dna(str(chrom), dna_cache, json_dir).upper()
        wlogger.info(f"Loaded DNA for {chrom}, length {len(dna)}")

        total = len(df)
        records = []
        last_decile = 0
        for i, row in enumerate(df.itertuples(index=False), start=1):
            records.append(
                _process_one_row(
                    (
                        row.tr_name,
                        row.chrom,
                        int(row.cstart),
                        int(row.cend),
                        int(row.peak_s),
                        int(row.peak_e),
                        int(row.chipscore),
                        row.jaspar,
                    ),
                    dna,
                    jaspar_boost,
                )
            )
            # progress every ~10%
            decile = (i * 10) // total
            if decile > last_decile:
                last_decile = decile
                wlogger.info("Progress: %d%% (%d/%d)", decile * 10, i, total)
        wlogger.info(f"Completed processing {len(records)} rows for {chrom_folder}")

        records_df = pd.DataFrame.from_records(records)
        records_df.to_csv(savepath, index=False)
        wlogger.info(f"Saved records to {savepath}")
        return savepath
    finally:
        # Release the log file: a pool worker may process several chromosomes.
        wlogger.removeHandler(handler)
        handler.close()
def build_dataset_fast_mp(
    fimo_out_dir: Path,
    json_dir: str,
    debug: bool,
    max_workers: int | None,
    jaspar_boost: int = 100,
    output_parts_folder: str | Path | None = None,
    keep_fimo_only: bool = False,
) -> list[str]:
    """Process every chromosome folder into a per-chrom CSV, optionally in parallel.

    Each ``chrom*`` folder under ``fimo_out_dir`` with a ``final.csv`` becomes
    one task for ``_process_one_chrom_folder`` (wrapped by ``_safe_process``).
    Tasks run serially when ``max_workers <= 1`` or there is a single task.
    Otherwise they run in a ``multiprocessing.Pool``.

    Args:
        fimo_out_dir: Directory containing the ``chrom*`` sub-folders.
        json_dir: Passed through to the workers for loading chromosome DNA.
        debug: If True, only process ``chromY`` (or the first folder if absent).
        max_workers: Pool size; ``None`` means ``mp.cpu_count()``.
        jaspar_boost: Score added at JASPAR hit positions.
        output_parts_folder: Where workers write ``<chrom>_processed.csv``.
        keep_fimo_only: Keep only rows that have a JASPAR/FIMO hit.

    Returns:
        Paths (as strings) of the per-chrom CSVs that were produced. Worker
        failures are logged, not raised, so the list may be partial.
    """
    # discover_chrom_folders joins with "/" and workers call .mkdir(): need Paths.
    fimo_out_dir = Path(fimo_out_dir)
    if output_parts_folder is not None:
        output_parts_folder = Path(output_parts_folder)

    chrom_folders = discover_chrom_folders(fimo_out_dir)
    if not chrom_folders:
        logger.warning(f"No chrom* folders with final.csv under {fimo_out_dir}")
        return []
    if debug:
        # keep chromY if present, otherwise just the first
        chrom_folders = [c for c in chrom_folders if c == "chromY"] or chrom_folders[:1]
        logger.info(f"DEBUG MODE: considering {chrom_folders[0]} only")

    tasks = [
        (
            cf,
            str(fimo_out_dir),
            json_dir,
            jaspar_boost,
            output_parts_folder,
            keep_fimo_only,
        )
        for cf in chrom_folders
    ]

    good_paths: list[Path] = []
    errs: list[tuple[str, str, str]] = []

    def _collect(status, payload):
        """Sort one worker result into good_paths or errs."""
        if status == "ok":
            p = Path(payload)
            if p.exists():
                good_paths.append(p)
            else:
                errs.append(("?", f"output missing: {p}", ""))
        else:
            chrom, msg, tb = payload
            errs.append((chrom, msg, tb))

    if (max_workers is not None and max_workers <= 1) or len(tasks) == 1:
        # Serial path (debug/deterministic or single task)
        for t in tasks:
            _collect(*_safe_process(t))
    else:
        # Parallel path
        procs = min(max_workers or mp.cpu_count(), len(tasks))
        logger.info(f"Using {procs} parallel workers for {len(tasks)} chrom folders")
        with mp.Pool(processes=procs, maxtasksperchild=10) as pool:
            for status, payload in pool.imap_unordered(
                _safe_process, tasks, chunksize=1
            ):
                _collect(status, payload)

    for chrom, msg, tb in errs:
        logger.error("Worker error for %s: %s\n%s", chrom, msg, tb)
    # Optionally: raise RuntimeError("One or more workers failed; see logs.")
    return [str(p) for p in good_paths]
def dedup_trname_peakseq_weighted(
    lf: pl.LazyFrame, seed: int = 42, outdir: str | None = None
) -> tuple[pl.LazyFrame, pd.DataFrame]:
    """Keep one row per (tr_name, peak_sequence), chosen with chromosome weighting.

    Every row gets weight ``w`` = fraction of all rows on its chromosome and a
    score ``log(w) + Gumbel(0, 1)`` (Gumbel-max sampling). The highest-scoring
    row of each (tr_name, peak_sequence) group is kept. The Gumbel noise comes
    from a hash of ``seed`` and the row's chrom/tr_name/peak_sequence, so the
    selection is repeatable for a given seed and input order.
    NOTE(review): Polars documents ``Expr.hash`` as not stable across Polars
    versions, so results may differ after upgrading Polars.

    Args:
        lf: Lazy frame with at least ``chrom``, ``tr_name`` and ``peak_sequence``.
        seed: Mixed into the hash that drives the random scores.
        outdir: Unused; kept for backward compatibility.

    Returns:
        ``(deduplicated lazy frame, pandas table comparing per-chrom counts and
        ratios before/after)``. Null ``peak_sequence`` values are replaced with
        the literal ``"<NULL>"`` in the returned frame.
    """
    # Normalize key dtypes
    lf = lf.with_columns(
        [
            pl.col("chrom").cast(pl.Utf8),
            pl.col("tr_name").cast(pl.Utf8),
            pl.col("peak_sequence").fill_null("<NULL>").cast(pl.Utf8),
        ]
    )
    # --- BEFORE: counts/ratios (materialize tiny table)
    pre_df = (
        lf.group_by("chrom")
        .len()
        .with_columns((pl.col("len") / pl.col("len").sum()).alias("pre_ratio"))
        .sort("chrom")
        .collect()
    )
    # Expected #groups (must equal result rows)
    exp_groups = lf.select(["tr_name", "peak_sequence"]).unique().collect().height
    logger.info(f"Expected groups: {exp_groups}")
    # Tiny weights table back to lazy
    pre_lf = pre_df.lazy().select(["chrom", "pre_ratio"]).rename({"pre_ratio": "w"})

    # Weighted random score = log(w) + Gumbel(0,1); tie-break by the hash itself
    TWO64 = 18446744073709551616.0
    eps = 1e-12
    h_expr = (
        pl.concat_str(
            [
                pl.lit(f"seed:{seed}"),
                pl.col("tr_name"),
                pl.col("peak_sequence"),
                pl.col("chrom"),
            ],
            separator="|",
        )
        .hash()
        .cast(pl.UInt64)
    )
    # Map the 64-bit hash to a uniform value in (0, 1), clamped away from 0 and 1
    u_expr = (h_expr.cast(pl.Float64) + 1.0) / pl.lit(TWO64)
    u_expr = (
        pl.when(u_expr < eps)
        .then(eps)
        .when(u_expr > 1 - eps)
        .then(1 - eps)
        .otherwise(u_expr)
    )
    logw_expr = (
        pl.when(pl.col("w").is_null() | (pl.col("w") <= 0))
        .then(eps)
        .otherwise(pl.col("w"))
        .log()
    )
    gumbel_expr = -(-u_expr.log()).log()
    score_expr = (logw_expr + gumbel_expr).alias("_score")
    hash_expr = h_expr.alias("_h")

    # Attach weights & scores, globally sort, then unique on the keys (keep first).
    # maintain_order=True makes the sort stable: rows sharing tr_name,
    # peak_sequence AND chrom have identical _score and _h, so without a stable
    # sort keep="first" would pick an arbitrary one of them on each run.
    lf_sorted = (
        lf.join(pre_lf, on="chrom", how="left")
        .with_columns([score_expr, hash_expr])
        .sort(["_score", "_h"], descending=[True, False], maintain_order=True)
    )
    lf_sel = lf_sorted.unique(subset=["tr_name", "peak_sequence"], keep="first").drop(
        ["w", "_score", "_h"]
    )

    # --- AFTER: counts/ratios
    post_df = (
        lf_sel.group_by("chrom")
        .len()
        .with_columns((pl.col("len") / pl.col("len").sum()).alias("post_ratio"))
        .sort("chrom")
        .collect(streaming=True)
    )
    compare_df = (
        (
            pre_df.select(["chrom", "len", "pre_ratio"])
            .rename({"len": "pre_n"})
            .join(
                post_df.select(["chrom", "len", "post_ratio"]).rename(
                    {"len": "post_n"}
                ),
                on="chrom",
                how="full",
            )
            .fill_null(0)
            .with_columns(
                (pl.col("post_ratio") - pl.col("pre_ratio")).abs().alias("abs_delta"),
                (
                    100
                    * (pl.col("post_ratio") - pl.col("pre_ratio"))
                    / pl.col("pre_ratio")
                )
                .abs()
                .alias("pcnt_delta"),
            )
            .sort("chrom")
        )
        .to_pandas()
        .drop(columns=["chrom_right"])
    )

    # --- Sanity: must keep exactly one per group
    got_rows = lf_sel.select(pl.len()).collect()["len"][0]
    if got_rows != exp_groups:
        # optional: raise or just log
        logger.warning(
            f"Dedup cardinality mismatch: expected {exp_groups}, got {got_rows}"
        )
    return lf_sel, compare_df
def write_map(lf: pl.LazyFrame, out_path: str, key: str, val: str, outname: str):
    """Dump a ``{key: val}`` JSON mapping built from the distinct (key, val) pairs of ``lf``.

    The file is written to ``<parent of out_path>/maps/<outname>`` (directory
    created if needed). Called for tr_seqid -> tr_sequence,
    peak_seqid -> peak_sequence and dna_seqid -> dna_sequence. If a key occurs
    with several values, the last pair seen wins.
    """
    target_dir = Path(out_path).parent / "maps"
    target_dir.mkdir(parents=True, exist_ok=True)
    pairs = lf.select([pl.col(key), pl.col(val)]).unique().collect(streaming=True)
    keys = pairs[key].to_list()
    values = pairs[val].to_list()
    mapping = {}
    for k, v in zip(keys, values):
        mapping[k] = v
    with open(target_dir / outname, "w") as fh:
        json.dump(mapping, fh, indent=2)
def _seeded_out_path(out_path: str, seed) -> str:
    """Insert ``_seed<seed>`` before the file extension (append ``.parquet`` if none)."""
    p = Path(out_path)
    if p.suffix:
        return str(p.with_name(f"{p.stem}_seed{seed}{p.suffix}"))
    return f"{out_path}_seed{seed}.parquet"


def _scan_processed_parts(paths_to_processed_dfs) -> pl.LazyFrame:
    """Lazily scan each per-chrom CSV, cast to one shared schema, and concatenate."""
    lfs = [
        pl.scan_csv(p).with_columns(  # don't use infer_schema_length=0 here
            [
                pl.col("chrom").cast(pl.Utf8),
                pl.col("tr_name").cast(pl.Utf8),
                pl.col("dna_sequence").cast(pl.Utf8),
                pl.col("peak_sequence").cast(pl.Utf8),
                pl.col("scores").cast(pl.Utf8),
                pl.col("chipscore").cast(pl.Float64),  # robust (int -> float OK)
                pl.col("total_jaspar_hits").cast(pl.Int64),
            ]
        )
        for p in paths_to_processed_dfs
    ]
    # Casting before the concat guarantees the schemas match.
    return pl.concat(lfs, how="vertical")


def _load_idmap(idmap_path, max_protein_len) -> pl.DataFrame:
    """Read the cleaned ID map as tr_name/tr_uniprot/tr_sequence plus tr_len.

    With ``max_protein_len`` set, only rows with ``tr_len <= max_protein_len``
    are kept (rows with a null sequence are dropped by that filter as well).
    """
    idmap = pl.read_csv(
        idmap_path, separator="\t", columns=["From", "Entry", "Sequence"]
    ).rename({"From": "tr_name", "Entry": "tr_uniprot", "Sequence": "tr_sequence"})
    # Vectorised per-sequence character count; nulls stay null.
    idmap = idmap.with_columns(
        pl.col("tr_sequence").str.len_chars().cast(pl.Int64).alias("tr_len")
    )
    if max_protein_len is not None:
        idmap = idmap.filter(pl.col("tr_len") <= max_protein_len)
        logger.info(f"Filtered valid TRs to only those with len <= {max_protein_len}")
    return idmap


def _assign_ids(lf: pl.LazyFrame) -> pl.LazyFrame:
    """Add chrom_peak_idx/chrpeak_id, dna/tr/peak sequence IDs and the row ``ID``.

    All indices are dense ranks (1, 2, 3, ...) in sorted value order, not in
    order of first appearance.
    """
    lf = lf.with_columns(
        [
            pl.col("dna_sequence").cast(pl.Utf8),
            pl.col("tr_sequence").cast(pl.Utf8),
            pl.col("peak_sequence").cast(pl.Utf8),
        ]
    )
    lf = lf.with_columns(
        [
            pl.col("peak_sequence").fill_null("").alias("peak_sequence"),
            pl.col("chrom").cast(pl.Utf8),
        ]
    )
    # Per-chromosome peak index: dense rank of peak_sequence within each chrom.
    lf = lf.with_columns(
        pl.col("peak_sequence")
        .rank(method="dense")
        .over("chrom")
        .cast(pl.Int64)
        .alias("chrom_peak_idx")
    )
    lf = lf.with_columns(
        pl.format("chr{}_peak{}", pl.col("chrom"), pl.col("chrom_peak_idx")).alias(
            "chrpeak_id"
        )
    )
    logger.info(f"Assigned unique chrpeak_ids per chromosome based on peak_sequence")

    # Global sequence IDs (dense rank over the whole table, no joins needed).
    lf = (
        lf.with_columns(
            [
                pl.col("dna_sequence").rank(method="dense").cast(pl.Int64).alias("dna_idx"),
                pl.col("tr_sequence").rank(method="dense").cast(pl.Int64).alias("tr_idx"),
                pl.col("peak_sequence").rank(method="dense").cast(pl.Int64).alias("peak_idx"),
            ]
        )
        .with_columns(
            [
                pl.format("dnaseq{}", pl.col("dna_idx")).alias("dna_seqid"),
                pl.format("trseq{}", pl.col("tr_idx")).alias("tr_seqid"),
                pl.format("peakseq{}", pl.col("peak_idx")).alias("peak_seqid"),
            ]
        )
        .drop(["dna_idx", "tr_idx", "peak_idx"])
    )
    logger.info(f"Assigned unique dna IDs, transcriptional regulator IDs, and peak IDs")

    # Row ID = "<tr_seqid>_<dna_seqid>"; null if either part is null
    # (e.g. a null dna_sequence) — the NULL counts logged later will show it.
    return lf.with_columns(
        pl.concat_str(
            [pl.col("tr_seqid"), pl.lit("_"), pl.col("dna_seqid")], ignore_nulls=False
        ).alias("ID")
    )


def _log_id_violations(lf: pl.LazyFrame) -> None:
    """Log dna_sequence <-> dna_seqid 1:1 checks and NULL ID counts (each collects ``lf``)."""
    # Each sequence maps to exactly one id
    viol1 = (
        lf.select("dna_sequence", "dna_seqid")
        .unique()
        .group_by("dna_sequence")
        .agg(pl.n_unique("dna_seqid").alias("n_ids"))
        .filter(pl.col("n_ids") > 1)
        .collect()
    )
    # Each id maps to exactly one sequence
    viol2 = (
        lf.select("dna_sequence", "dna_seqid")
        .unique()
        .group_by("dna_seqid")
        .agg(pl.n_unique("dna_sequence").alias("n_seqs"))
        .filter(pl.col("n_seqs") > 1)
        .collect()
    )
    logger.info(
        "viol1 rows (seq→>1 id): %d; viol2 rows (id→>1 seq): %d",
        viol1.height,
        viol2.height,
    )
    nulls = lf.select(
        [
            pl.col("dna_seqid").is_null().sum().alias("null_dna_seqid"),
            pl.col("tr_seqid").is_null().sum().alias("null_tr_seqid"),
            pl.col("ID").is_null().sum().alias("null_ID"),
        ]
    ).collect()
    logger.info("NULL counts:\n%s", nulls)


def _write_table(lf_out: pl.LazyFrame, out_path_full: str) -> None:
    """Write ``lf_out`` as Parquet or CSV depending on the extension (default Parquet)."""
    lower = out_path_full.lower()
    if lower.endswith(".parquet"):
        lf_out.sink_parquet(
            out_path_full,
            compression="zstd",
            statistics=True,
            row_group_size=128_000,
        )
        logger.info(f"Wrote parquet file to {out_path_full}")
    elif lower.endswith(".csv"):
        # NOTE: collect(streaming=True) still returns an in-memory DataFrame;
        # prefer Parquet for very large outputs.
        lf_out.collect(streaming=True).write_csv(out_path_full)
        logger.info(f"Wrote csv file to {out_path_full}")
    else:
        # default to Parquet if no/unknown extension
        target = out_path_full + ".parquet"
        lf_out.sink_parquet(target, compression="zstd", statistics=True)
        logger.info(f"Wrote parquet file to {target}")


def combine_processed_with_polars(
    paths_to_processed_dfs: list[str],
    idmap_path: str,  # TSV with columns: From, Entry, Sequence
    out_path: str,  # e.g., "processed_out.parquet" or ".csv"
    max_protein_len: int = None,
    check_violations: bool = False,
    seeds: list | None = None,
):
    """Combine per-chrom parts into the final table(s), one output per seed.

    Steps:
      1. Scan and concatenate all part CSVs.
      2. Keep only TRs present in the ID map (optionally length-filtered).
      3. Join the UniProt accession/sequence/length.
      4. Assign chrpeak/sequence IDs.
      5. Write the ID->sequence JSON maps.
      6. For each seed, deduplicate (tr_name, peak_sequence) with chromosome
         weighting, log sanity checks, and write
         ``<out_path stem>_seed<seed><ext>`` plus a per-chrom ratio comparison CSV.
    Finally, the first 1000 rows of the last seed's table are written as an
    example CSV.

    Args:
        paths_to_processed_dfs: Per-chrom CSVs written by the workers.
        idmap_path: Cleaned UniProt ID map TSV (From, Entry, Sequence).
        out_path: Base output path; the extension selects Parquet/CSV.
        max_protein_len: Optional maximum TR sequence length.
        check_violations: Currently unused; the ID sanity checks always run.
        seeds: Dedup seeds; defaults to ``[0]``.
    """
    if not paths_to_processed_dfs:
        logger.info("No records produced; nothing to write.")
        return
    if seeds is None:
        seeds = [0]

    # 1) Scan each CSV with a shared schema and concat
    lf_og = _scan_processed_parts(paths_to_processed_dfs)

    # 2) Keep only TRs that have a (length-filtered) UniProt sequence
    idmap = _load_idmap(idmap_path, max_protein_len)
    success_trs = list(idmap["tr_name"].unique())
    logger.info(f"Total valid TRs: {len(success_trs)}")
    lf_og = lf_og.filter(pl.col("tr_name").is_in(success_trs))

    out_path = str(out_path)
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    # 3) Join the small, eager idmap once; every per-seed frame derives from this
    lf_og = lf_og.join(idmap.lazy(), on="tr_name", how="left")
    logger.info(f"Merged in UniProt IDs and TR sequences from UniProt ID mappping")

    # 4) IDs
    lf_og = _assign_ids(lf_og)

    # 5) ID -> sequence maps over all the data
    for key, val, outname in (
        ("tr_seqid", "tr_sequence", "tr_seqid_to_tr_sequence.json"),
        ("peak_seqid", "peak_sequence", "peak_seqid_to_peak_sequence.json"),
        ("dna_seqid", "dna_sequence", "dna_seqid_to_dna_sequence.json"),
    ):
        write_map(lf_og, out_path=out_path, key=key, val=val, outname=outname)

    cols = [
        "ID",
        "tr_seqid",
        "dna_seqid",
        "peak_seqid",
        "chrpeak_id",
        "tr_name",
        "chipscore",
        "total_jaspar_hits",
        "dna_sequence",
        "tr_sequence",
        "scores",
    ]

    # 6) One deduplicated output per seed
    lf_out = None  # the last seed's table is used for the example file
    for seed in seeds:
        out_path_full = _seeded_out_path(out_path, seed)
        lf, compare_df = dedup_trname_peakseq_weighted(lf_og, seed=seed)
        logger.info(
            f"Dropped duplicate examples of tr_name + peak_sequence. Maintained chrom distribution with weighted random sampling (seed={seed})."
        )
        # Save comparison df. Annotate with debug if it's a debug run
        compare_df_path = str(
            Path(out_path).parent / f"chrom_ratio_compare_seed{seed}.csv"
        )
        if "debug" in out_path:
            compare_df_path = compare_df_path.replace(".csv", "_debug.csv")
        compare_df.to_csv(compare_df_path, index=False)

        # lf already carries tr_uniprot/tr_sequence/tr_len from step 3; joining
        # the idmap again would only add *_right duplicate columns.
        logger.info(f"Applied dna_sequence and tr_sequence IDs to main table")
        _log_id_violations(lf)

        lf_out = lf.select(cols)
        logger.info(f"Selected final columns")
        _write_table(lf_out, out_path_full)

    if lf_out is None:
        logger.info("No seeds given; no output table was written.")
        return

    # Save the FIRST 1000 rows to CSV (streaming-friendly)
    df_first = lf_out.limit(1000).collect(streaming=True)
    example_out_path = (
        Path(root)
        / "dpacman/data_files/processed/remap/examples"
        / "example1000_remap2022_crm_fimo_output_q_processed.csv"
    )
    example_out_path.parent.mkdir(parents=True, exist_ok=True)
    df_first.write_csv(example_out_path)
    logger.info(f"Wrote first 1000 rows to {example_out_path} as an example")
# FIMO check
def get_reverse_complement(s):
    """Reverse-complement a DNA string, returning it 5' to 3'.

    Case is preserved (A/T and C/G swap, N stays N). Characters outside
    ``ACGTNacgtn`` raise ``KeyError``.
    """
    pairing = {
        "a": "t",
        "c": "g",
        "t": "a",
        "g": "c",
        "A": "T",
        "C": "G",
        "T": "A",
        "G": "C",
        "n": "n",
        "N": "N",
    }
    complemented = "".join([pairing[base] for base in s])
    return complemented[::-1]
def extract_jaspar_motifs(row, reverse_complement=False):
    """Rebuild the JASPAR-hit region of a processed row as a gapped string.

    Positions whose score is above the row's ChIP score are the boosted JASPAR
    positions. The span from the first to the last such position is returned.
    Boosted positions show the base from the sequence; other positions show "-".

    Args:
        row: Mapping with ``scores`` (comma-separated ints), ``total_jaspar_hits``,
            ``chipscore`` and ``dna_sequence`` (and ``dna_sequence_rc`` when
            ``reverse_complement`` is True).
        reverse_complement: Read bases from ``dna_sequence_rc`` instead.
            NOTE(review): the positions are still the forward-strand ones, so this
            is not the reverse complement of the forward motif.

    Returns:
        The gapped motif string, or ``""`` when there are no hits or no hit
        position falls inside the sequence window.
    """
    scores = [int(x) for x in row["scores"].split(",")]
    if row["total_jaspar_hits"] == 0:
        return ""
    chipscore = row["chipscore"]
    dna_seq = row["dna_sequence"]
    if reverse_complement:
        dna_seq = row["dna_sequence_rc"]

    hit_positions = [i for i, v in enumerate(scores) if v > chipscore]
    if not hit_positions:
        # total_jaspar_hits also counts hits clipped entirely outside the window,
        # so a non-zero count can still leave no boosted positions.
        return ""
    hit_set = set(hit_positions)  # O(1) membership instead of list scans
    return "".join(
        dna_seq[i] if i in hit_set else "-"
        for i in range(hit_positions[0], hit_positions[-1] + 1)
    )
def clean_idmap(idmap_path):
    """Resolve ambiguous UniProt mappings and write the cleaned ID map.

    The raw UniProt ID map returned several entries for some TRs; the correct
    ones were looked up on ReMap and hard-coded below. For TRs in that manual
    map only the row whose ``Entry`` matches the manual accession is kept. All
    other rows are kept unchanged.

    Args:
        idmap_path: Tab-separated UniProt ID-mapping file with ``From`` (TR name)
            and ``Entry`` (accession) columns.

    Returns:
        Path of the cleaned TSV written under the project's processed/remap folder.

    Raises:
        ValueError: if any TR still maps to more than one row after cleaning.
    """
    manual_map = {
        "BACH1": "O14867",
        "BAP1": "Q92560",
        "BDP1": "A6H8Y1",
        "BRF1": "Q92994",
        "CUX1": "Q13948",
        "DDX21": "Q9NR30",
        "ERG": "P11308",
        "HBP1": "O60381",
        "KLF14": "Q8TD94",
        "MED1": "Q15648",
        "MED25": "Q71SY5",
        "MGA": "Q8IWI9",
        "NRF1": "Q16656",
        "PAF1": "Q8N7H5",
        "PDX1": "P52945",
        "RBP2": "P50120",
        "RLF": "Q13129",
        "SP1": "P08047",
        "SPIN1": "Q9Y657",
        "STAG1": "Q8WVM7",
        "TAF15": "Q92804",
        "TCF3": "P15923",
        "ZFP36": "P26651",
        "EVI1": "Q03112",
        "MCM2": "P49736",
    }
    idmap = pd.read_csv(idmap_path, sep="\t")

    # Expected accession per row: the manual choice if one exists, else the row's own.
    expected_entry = idmap["From"].map(manual_map).fillna(idmap["Entry"])
    idmap_remapped = idmap.loc[idmap["Entry"] == expected_entry].reset_index(drop=True)

    # Manual entries whose accession never appeared would silently drop the TR.
    lost = sorted(
        (set(manual_map) & set(idmap["From"])) - set(idmap_remapped["From"])
    )
    if lost:
        logger.warning(
            f"Manual UniProt mapping matched no rows for: {', '.join(lost)}"
        )

    # Explicit check (an assert would be stripped under `python -O`).
    dupes = idmap_remapped.loc[idmap_remapped["From"].duplicated(), "From"].unique()
    if len(dupes):
        raise ValueError(
            f"TRs still mapped to multiple UniProt entries: {sorted(map(str, dupes))}"
        )
    logger.info(
        f"Total transcriptional regulators successfully mapped in UniProt: {len(idmap_remapped)}"
    )
    clean_idmap_path = (
        Path(root)
        / "dpacman/data_files/processed/remap/idmapping_reviewed_true_processed_2025_08_11.tsv"
    )
    clean_idmap_path.parent.mkdir(parents=True, exist_ok=True)
    idmap_remapped.to_csv(clean_idmap_path, sep="\t")
    return clean_idmap_path
def debug_fimo_check(
    path_to_chrom_fimo, path_to_processed_chrom, chrom="Y", json_dir=""
):
    """Debug check that FIMO motif coordinates line up with our processed rows.

    For one randomly chosen valid FIMO match per strand ("+" then "-", fixed
    random_state), the check does four things:
      1. Re-extracts the motif from the chromosome DNA.
      2. Logs whether it equals FIMO's ``matched_sequence`` (reverse-complemented
         for "-").
      3. Finds processed rows whose context contains the ChIP-seq region and
         whose predicted motif string contains the forward-strand motif.
      4. Logs whether the FIMO TR is among them.
    A match is "valid" when its motif TR equals the TR encoded in
    ``sequence_name`` and that TR is present in the processed file. Results are
    only logged; a strand with no valid match is skipped with a warning.
    """
    processed = pd.read_csv(path_to_processed_chrom)
    processed["pred_motif_string"] = processed.apply(
        lambda row: extract_jaspar_motifs(row), axis=1
    )
    processed["dna_sequence_rc"] = processed["dna_sequence"].apply(
        lambda x: get_reverse_complement(x)
    )
    processed["pred_motif_string_rc"] = processed.apply(
        lambda row: extract_jaspar_motifs(row, reverse_complement=True), axis=1
    )
    processed_trs = processed["tr_name"].unique().tolist()

    fimo = pd.read_csv(path_to_chrom_fimo)
    # NOTE(review): assumes the 3rd "_"-separated field of sequence_name is the TR
    # and its last two fields are the ChIP-seq start/end (as parsed below).
    fimo["input_tr"] = fimo["sequence_name"].str.split("_", expand=True)[2]
    fimo_valid = fimo.loc[
        (fimo["motif_alt_id"] == fimo["input_tr"])
        & (fimo["motif_alt_id"].isin(processed_trs))
    ].reset_index(drop=True)
    logger.info(f"Total valid FIMO matches: {len(fimo_valid)}")
    logger.info(
        f"Total transcriptional regulators being considered: {len(processed_trs)}"
    )

    # Load DNA
    cache_debug = load_chrom_dna(chrom, {}, json_dir=json_dir)

    for strand in ("+", "-"):
        candidates = fimo_valid.loc[fimo_valid["strand"] == strand]
        if candidates.empty:
            logger.warning(f"No valid FIMO matches on strand '{strand}'; skipping")
            continue
        row = candidates.sample(n=1, random_state=44)

        chipseq_start, chipseq_end = [
            int(x) for x in row["sequence_name"].item().split("_")[-2:]
        ]
        logger.info(f"ChIPseq start: {chipseq_start}, ChIPseq end: {chipseq_end}")
        logger.info(f"Strand: {strand}")
        # FIMO start/stop treated as 1-based inclusive within the region (hence -1).
        motif_start = chipseq_start + int(row["start"].item()) - 1
        motif_end = chipseq_start + int(row["stop"].item())
        motif = row["matched_sequence"].item()
        full_seq = cache_debug[chipseq_start:chipseq_end].upper()
        logger.info(f"Full sequence: {full_seq}")
        logger.info(
            f"Full sequence reverse complement: {get_reverse_complement(full_seq)}"
        )
        our_motif = cache_debug[motif_start:motif_end].upper()
        if strand == "+":
            logger.info(f"True motif found by FIMO: {motif}")
            logger.info(f"Extracted motif on our end: {our_motif}")
            logger.info(f"Correct extraction: {motif==our_motif}")
        else:
            our_motif_rc = get_reverse_complement(our_motif)
            logger.info(f"True motif found by FIMO: {motif}")
            logger.info(f"Extracted motif on our end: {our_motif_rc}")
            logger.info(f"Correct extraction: {motif==our_motif_rc}")
            logger.info(f"Motif that will appear in the forward sequence: {our_motif}")

        # Processed rows store forward-strand sequence, so both strands match on
        # our_motif. regex=False: these are literal DNA substrings.
        matching_rows = processed.loc[
            processed["dna_sequence"].str.contains(full_seq, regex=False)
            & processed["pred_motif_string"].str.contains(our_motif, regex=False)
        ]
        # Now find if there are rows with the same TR, and the same DNA sequence and motif
        matching_row_trs = sorted(matching_rows["tr_name"].unique().tolist())
        expected_tr = row["motif_alt_id"].item()
        logger.info(f"TR from selected row: {expected_tr}")
        logger.info(f"TRs with same motif: {','.join(matching_row_trs)}")
        logger.info(f"Expected TR in list: {expected_tr in matching_row_trs}")
def debug_remap_check(remap_path, path_to_processed_chrom, chrom="Y", json_dir=""):
    """Debug check: does a random ReMap peak's sequence map back to the same TRs?

    The check proceeds as follows:
      1. Sample one ReMap row on ``chrom`` (fixed random_state).
      2. Slice that peak out of the chromosome DNA.
      3. Collect the TRs ReMap lists for exactly that peak (same chrom, start,
         end) and the TRs whose ``peak_sequence`` equals the slice in the
         processed per-chrom CSV.
      4. Compare the two TR sets.
    Results are only logged.

    NOTE(review): rows removed upstream (e.g. by keep_fimo_only) or identical
    sequences at other loci can make the two lists differ legitimately.
    """
    remap = pd.read_csv(remap_path)
    remap["ChIPStart"] = remap["ChIPStart"].astype(int)
    remap["ChIPEnd"] = remap["ChIPEnd"].astype(int)
    # Compare as strings so numeric-looking chromosome labels still match.
    on_chrom = remap.loc[remap["#chrom"].astype(str) == str(chrom)]
    if on_chrom.empty:
        logger.warning(f"No ReMap rows found for chromosome {chrom}; skipping check")
        return
    row = on_chrom.sample(n=1, random_state=42)
    start, end = row["ChIPStart"].item(), row["ChIPEnd"].item()
    cache_debug = load_chrom_dna(chrom, {}, json_dir=json_dir)
    test_seq = cache_debug[start:end].upper()
    logger.info(
        f"Randomly sampled sequence ({len(test_seq)} nucleotides), chr{chrom} {start}:{end}\n\tsequence: {test_seq}"
    )
    # Same peak = same chromosome AND same coordinates.
    should_find = sorted(
        on_chrom.loc[
            (on_chrom["ChIPStart"] == start) & (on_chrom["ChIPEnd"] == end), "TR"
        ]
        .unique()
        .tolist()
    )
    logger.info(
        f"Expect to find {len(should_find)} TRs: {', '.join(should_find)}"
    )
    processed = pd.read_csv(path_to_processed_chrom)
    did_find = sorted(
        processed.loc[processed["peak_sequence"] == test_seq, "tr_name"]
        .unique()
        .tolist()
    )
    logger.info(
        f"Looked up same sequence in processed chr{chrom} file.\nFound TRs: {', '.join(did_find)}"
    )
    # Both lists are sorted, so this comparison is order-insensitive.
    logger.info(f"found==expected: {did_find==should_find}")
def main(cfg: DictConfig):
    """Post-process per-chromosome FIMO outputs into the final training table(s).

    Steps:
      1. Clean the UniProt ID map.
      2. Build (or reuse) one processed CSV per chromosome under
         ``<processed_output_csv parent>/temp_parts``.
      3. In debug mode, run the chrY sanity checks.
      4. Combine all parts with Polars into one output per seed.

    Args:
        cfg: Hydra config. Reads these keys from ``cfg.data_task``: debug,
            json_dir, fimo_out_dir, processed_output_csv, max_workers (optional),
            idmap_path, jaspar_boost, keep_fimo_only, remap_path (debug only),
            max_protein_len, seeds.
    """
    data_cfg = cfg.data_task
    debug = bool(data_cfg.debug)
    json_dir = data_cfg.json_dir
    fimo_out_dir = Path(root) / data_cfg.fimo_out_dir
    processed_output_csv = Path(root) / data_cfg.processed_output_csv
    parts_dir = processed_output_csv.parent / "temp_parts"
    max_workers = getattr(data_cfg, "max_workers", None)
    logger.info(f"Debug: {debug}")
    logger.info(f"Reading per-chrom final.csv under: {fimo_out_dir}")

    # Resolve ambiguous UniProt mappings; returns the path of the cleaned TSV.
    clean_idmap_path = clean_idmap(Path(root) / data_cfg.idmap_path)

    # Reuse existing parts only when the folder holds at least 24 entries
    # (presumably one per chromosome 1-22, X, Y); otherwise (re)build them.
    reuse_parts = os.path.exists(parts_dir) and len(os.listdir(parts_dir)) >= 24
    if reuse_parts:
        paths_to_processed_dfs = [parts_dir / name for name in os.listdir(parts_dir)]
    else:
        paths_to_processed_dfs = build_dataset_fast_mp(
            fimo_out_dir=fimo_out_dir,
            json_dir=json_dir,
            debug=debug,
            max_workers=max_workers,
            jaspar_boost=data_cfg.jaspar_boost,
            output_parts_folder=parts_dir,
            keep_fimo_only=data_cfg.keep_fimo_only,
        )

    # Debug methods: (1) peak sequences correspond to ReMap, (2) FIMO sequences
    # correspond to the FIMO results. Debug output also gets a "_debug" suffix.
    out_path = str(processed_output_csv).replace(".csv", ".parquet")
    if debug:
        chrom_y_part = Path(parts_dir) / "chromY_processed.csv"
        debug_remap_check(
            remap_path=Path(root) / data_cfg.remap_path,
            path_to_processed_chrom=chrom_y_part,
            chrom="Y",
            json_dir=json_dir,
        )
        debug_fimo_check(
            path_to_chrom_fimo=Path(root)
            / data_cfg.fimo_out_dir
            / "chromY"
            / "fimo_annotations.csv",
            path_to_processed_chrom=chrom_y_part,
            chrom="Y",
            json_dir=json_dir,
        )
        out_path = out_path.replace(".parquet", "_debug.parquet")

    logger.info(f"Combining {len(paths_to_processed_dfs)} processed parts with Polars")
    combine_processed_with_polars(
        paths_to_processed_dfs=paths_to_processed_dfs,
        idmap_path=clean_idmap_path,
        out_path=out_path,
        max_protein_len=data_cfg.max_protein_len,
        seeds=data_cfg.seeds,
    )

    # Temporary parts are intentionally kept (the next run reuses them above);
    # flip this flag to delete them after combining.
    cleanup_temp_parts = False
    if cleanup_temp_parts and parts_dir.exists():
        for part in parts_dir.glob("*.csv"):
            part.unlink()
        parts_dir.rmdir()
        logger.info(f"Cleaned up temporary files in {parts_dir}")
if __name__ == "__main__":
    # On some clusters with older Python, 'fork' is default and fine.
    # If you hit issues (e.g., with threads/IO), uncomment spawn:
    # mp.set_start_method("spawn", force=True)
    #
    # main() requires a Hydra DictConfig (it reads cfg.data_task.*), so calling
    # main() with no arguments can only fail with a TypeError. Launch it from a
    # Hydra entry point that passes the composed config instead.
    raise SystemExit(
        "post_fimo.main(cfg) requires a Hydra config; "
        "call it from a Hydra entry point that passes cfg."
    )