"""MSigDB data processing pipeline for contrastive pretraining.

Strategy:
    1. Download full GMT files from the Broad data server (no auth needed).
    2. Fetch brief descriptions from MSigDB HTML gene set card pages
       (no auth needed).
    3. Build a text-gene paired dataset and split it by collection.

Usage:
    python data_processing.py
"""
import json
import os
import time
import html
import re
import random
import concurrent.futures
from collections import defaultdict

import requests

BROAD_BASE = "https://data.broadinstitute.org/gsea-msigdb/msigdb/release"
VERSION = "2024.1"

# Collection code -> GMT file prefix on the Broad release server.
HUMAN_GMTS = {
    "H": "h.all",
    "C1": "c1.all",
    "C2": "c2.all",
    "C3": "c3.all",
    "C4": "c4.all",
    "C5": "c5.all",
    "C6": "c6.all",
    "C7": "c7.all",
    "C8": "c8.all",
}
MOUSE_GMTS = {
    "MH": "mh.all",
    "M1": "m1.all",
    "M2": "m2.all",
    "M3": "m3.all",
    "M5": "m5.all",
    "M8": "m8.all",
}

# Matches the "Brief description" label and captures the contents of the next
# <td> cell. NOTE(review): assumes the card page renders the value in a <td>
# following the label — confirm against a live card page if parsing yields "".
_BRIEF_DESC_RE = re.compile(
    r"Brief\s+description.*?<td[^>]*>(.*?)</td>", re.DOTALL | re.IGNORECASE
)
_TAG_RE = re.compile(r"<[^>]+>")
# Description values treated as "no description".
_MISSING_DESCRIPTIONS = {"na", "n/a"}


def _save_json(obj, path):
    """Write *obj* as JSON to *path* via a temp file.

    ``os.replace`` swaps the file in one step, so an interrupted run never
    leaves a truncated JSON file behind for the next run to choke on.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as fp:
        json.dump(obj, fp)
    os.replace(tmp_path, path)


def _parse_gmt(text, collection_code, species_label):
    """Parse GMT text (``name<TAB>url<TAB>gene<TAB>gene...``) into dicts.

    Lines with fewer than three tab-separated fields are skipped; empty gene
    fields are dropped.
    """
    gene_sets = []
    for line in text.splitlines():  # splitlines also copes with CRLF files
        parts = line.split("\t")
        if len(parts) < 3:
            continue
        gene_sets.append({
            "name": parts[0].strip(),
            "url": parts[1].strip(),
            "genes": [g.strip() for g in parts[2:] if g.strip()],
            "collection": collection_code,
            "species": species_label,
        })
    return gene_sets


def download_gmt(collection_code, species="Hs", version=VERSION):
    """Download and parse one MSigDB GMT file.

    Args:
        collection_code: Key of ``HUMAN_GMTS`` (when species is "Hs") or of
            ``MOUSE_GMTS`` (any other species value), e.g. "H" or "M2".
        species: "Hs" selects the human table; anything else selects mouse.
        version: MSigDB release string used in the URL, e.g. "2024.1".

    Returns:
        A list of dicts with keys ``name``, ``url``, ``genes``, ``collection``
        and ``species`` ("human" or "mouse"). Empty if the collection code is
        unknown or the download fails (a warning is printed).
    """
    gmts = HUMAN_GMTS if species == "Hs" else MOUSE_GMTS
    prefix = gmts.get(collection_code)
    if not prefix:
        return []
    filename = f"{prefix}.v{version}.{species}.symbols.gmt"
    url = f"{BROAD_BASE}/{version}.{species}/{filename}"
    try:
        resp = requests.get(url, timeout=60)
        resp.raise_for_status()
    except requests.RequestException as e:
        # Best effort: one unavailable collection should not abort the run.
        print(f" Warning: {url}: {e}")
        return []
    species_label = "human" if species == "Hs" else "mouse"
    return _parse_gmt(resp.text, collection_code, species_label)


def download_all_gmts(output_dir="data/raw"):
    """Download every human and mouse collection and save them as JSON.

    Writes ``all_gmt_genesets.json`` into *output_dir* (created if needed)
    and returns the combined list of gene set dicts from ``download_gmt``,
    human collections first.
    """
    os.makedirs(output_dir, exist_ok=True)
    all_gs = []
    for label, species, table in (("human", "Hs", HUMAN_GMTS),
                                  ("mouse", "Mm", MOUSE_GMTS)):
        print(f"Downloading {label} gene sets...")
        for code in table:
            gs = download_gmt(code, species)
            all_gs.extend(gs)
            print(f" {code}: {len(gs)}")
    _save_json(all_gs, os.path.join(output_dir, "all_gmt_genesets.json"))
    print(f"Total: {len(all_gs)}")
    return all_gs


def _parse_brief_description(page_html):
    """Extract the plain-text "Brief description" from a card page's HTML.

    Tags are removed, HTML entities decoded and whitespace collapsed to single
    spaces. Returns "" when the field is absent or holds an NA placeholder.
    """
    match = _BRIEF_DESC_RE.search(page_html)
    if not match:
        return ""
    # Replace tags with a space (not "") so "a<br>b" does not fuse into "ab".
    desc = " ".join(html.unescape(_TAG_RE.sub(" ", match.group(1))).split())
    if desc.lower() in _MISSING_DESCRIPTIONS:
        return ""
    return desc


def _fetch_description(name, species="human"):
    """Fetch one brief description, or ``None`` if the failure may be transient.

    A 404 is treated as permanent and returns "" so callers can cache it.
    Network errors and other HTTP errors (e.g. 429, 5xx) return ``None`` so
    callers can leave them uncached and retry on a later run.
    """
    url = f"https://www.gsea-msigdb.org/gsea/msigdb/{species}/geneset/{name}"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
    except requests.HTTPError as err:  # must precede RequestException (subclass)
        status = err.response.status_code if err.response is not None else None
        return "" if status == 404 else None
    except requests.RequestException:
        return None
    return _parse_brief_description(resp.text)


def fetch_description_html(name, species="human"):
    """Return the brief description of gene set *name*, or "" if unavailable.

    Args:
        name: MSigDB gene set standard name.
        species: URL path segment, "human" or "mouse".
    """
    return _fetch_description(name, species) or ""


def fetch_descriptions_batch(gene_sets, max_workers=10,
                             cache_path="data/raw/descriptions_cache.json"):
    """Fetch descriptions for all gene sets concurrently, with a JSON cache.

    Args:
        gene_sets: Dicts with at least ``name`` and ``species`` keys.
        max_workers: Number of concurrent HTTP worker threads.
        cache_path: JSON file mapping gene set name -> description. It is
            loaded if present, checkpointed every 500 fetches, and rewritten
            at the end.

    Returns:
        The cache dict (name -> description). Names whose fetch failed
        transiently are left out so that a rerun retries them.
    """
    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path, encoding="utf-8") as f:
            cache = json.load(f)
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    # The cache is keyed by name only, so fetch each uncached name once
    # (keeping the first species seen) even if it appears in several sets.
    pending = {}
    for gs in gene_sets:
        if gs["name"] not in cache:
            pending.setdefault(gs["name"], gs["species"])
    to_fetch = list(pending.items())
    print(f"Need to fetch {len(to_fetch)} descriptions ({len(cache)} cached)")

    if to_fetch:
        fetched = 0
        failed = 0
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
            futures = {ex.submit(_fetch_description, n, s): n for n, s in to_fetch}
            for fut in concurrent.futures.as_completed(futures):
                desc = fut.result()
                if desc is None:
                    failed += 1  # transient: leave uncached so a rerun retries
                else:
                    cache[futures[fut]] = desc
                fetched += 1
                if fetched % 500 == 0:
                    _save_json(cache, cache_path)
                    print(f" {fetched}/{len(to_fetch)}")
        if failed:
            print(f" Warning: {failed} descriptions failed to fetch; rerun to retry")
    _save_json(cache, cache_path)
    return cache


def build_pairs(gene_sets, descriptions, min_genes=5, max_genes=2000):
    """Build text/gene-list training pairs.

    The text is a header line with collection and species, then the gene set
    name with underscores replaced by spaces, then (if available) the cleaned
    description, joined by newlines. Gene sets whose size is outside
    ``[min_genes, max_genes]`` are skipped.

    Returns:
        A list of dicts with keys ``id``, ``text``, ``genes``, ``n_genes``,
        ``collection``, ``species`` and ``has_description``.
    """
    pairs = []
    for gs in gene_sets:
        genes = gs["genes"]
        if not min_genes <= len(genes) <= max_genes:
            continue
        raw_desc = descriptions.get(gs["name"], "") or ""
        # Strip any leftover markup and collapse whitespace so the description
        # stays on one line and does not blur the newline-separated fields.
        desc = " ".join(html.unescape(_TAG_RE.sub(" ", raw_desc)).split())
        parts = [
            f"[Collection: {gs['collection']}] [Species: {gs['species']}]",
            gs["name"].replace("_", " "),
        ]
        if desc:
            parts.append(desc)
        text = "\n".join(parts)
        # Defensive length floor; the current header alone already exceeds it.
        if len(text) < 30:
            continue
        # NOTE(review): "id" is the bare gene set name, which can repeat
        # across species — confirm downstream code does not need unique ids.
        pairs.append({
            "id": gs["name"],
            "text": text,
            "genes": genes,
            "n_genes": len(genes),
            "collection": gs["collection"],
            "species": gs["species"],
            "has_description": bool(desc),
        })
    return pairs


def split_and_save(pairs, output_dir="data/processed"):
    """Split pairs by collection and write ``train/val/test.jsonl``.

    Collections that are not listed explicitly go to the train split.
    Returns a dict mapping split name -> list of pairs.
    """
    os.makedirs(output_dir, exist_ok=True)
    train_cols = {"C2", "C5", "C8", "C1", "M2", "M5", "M8", "M1"}
    val_cols = {"C3", "C4", "M3"}
    test_cols = {"H", "C6", "C7", "MH"}
    splits = {"train": [], "val": [], "test": []}
    for p in pairs:
        c = p["collection"]
        if c in train_cols:
            splits["train"].append(p)
        elif c in val_cols:
            splits["val"].append(p)
        elif c in test_cols:
            splits["test"].append(p)
        else:
            splits["train"].append(p)
    for name, data in splits.items():
        path = os.path.join(output_dir, f"{name}.jsonl")
        with open(path, "w", encoding="utf-8") as f:
            for r in data:
                f.write(json.dumps(r) + "\n")
        print(f"{name}: {len(data)} pairs -> {path}")
    return splits


if __name__ == "__main__":
    all_gs = download_all_gmts()
    descs = fetch_descriptions_batch(all_gs)
    pairs = build_pairs(all_gs, descs)
    print(f"\nTotal pairs: {len(pairs)}")
    split_and_save(pairs)