| """ |
MSigDB Data Processing Pipeline for Contrastive Pretraining.

Strategy:
1. Download full GMT files from the Broad data server (no auth needed)
2. Fetch brief descriptions from the MSigDB HTML gene-set card pages (no auth needed)
3. Build the text-gene paired dataset
4. Split the pairs by collection into train/val/test JSONL files

Usage:
    python data_processing.py
| """ |
|
|
| import json |
| import os |
| import time |
| import html |
| import re |
| import random |
| import concurrent.futures |
| from collections import defaultdict |
| import requests |
|
|
# Base URL of the MSigDB release directories on the Broad data server.
BROAD_BASE = "https://data.broadinstitute.org/gsea-msigdb/msigdb/release"
# MSigDB release string; it is embedded in both the directory and file names.
VERSION = "2024.1"

# Collection code -> GMT file-name prefix. A file is named
# "<prefix>.v<VERSION>.<Hs|Mm>.symbols.gmt". Dict order is the download order.
HUMAN_GMTS = {
    "H": "h.all",
    "C1": "c1.all",
    "C2": "c2.all",
    "C3": "c3.all",
    "C4": "c4.all",
    "C5": "c5.all",
    "C6": "c6.all",
    "C7": "c7.all",
    "C8": "c8.all",
}
MOUSE_GMTS = {
    "MH": "mh.all",
    "M1": "m1.all",
    "M2": "m2.all",
    "M3": "m3.all",
    "M5": "m5.all",
    "M8": "m8.all",
}
|
|
|
|
def _parse_gmt(text, collection_code, species_label):
    """Parse GMT text into a list of gene-set record dicts.

    Each line is tab-separated: set name, a second field (kept as "url"),
    then the gene symbols. Lines with fewer than three fields are skipped,
    and blank gene entries are dropped.
    """
    gene_sets = []
    for line in text.splitlines():
        parts = line.split("\t")
        if len(parts) < 3:
            continue
        gene_sets.append({
            "name": parts[0].strip(),
            "url": parts[1].strip(),
            "genes": [g.strip() for g in parts[2:] if g.strip()],
            "collection": collection_code,
            "species": species_label,
        })
    return gene_sets


def download_gmt(collection_code, species="Hs", version=VERSION):
    """Download one MSigDB GMT file and parse it into gene-set records.

    Args:
        collection_code: Key into HUMAN_GMTS (when species == "Hs") or
            MOUSE_GMTS (any other species value), e.g. "H" or "M2".
        species: "Hs" selects the human table; any other value the mouse
            table. Also used verbatim in the file name and URL.
        version: MSigDB release string used in the file name and URL.

    Returns:
        A list of dicts with keys "name", "url", "genes", "collection" and
        "species" ("human" or "mouse"). Returns [] for an unknown collection
        code, or when the download fails (a warning is printed instead).
    """
    gmts = HUMAN_GMTS if species == "Hs" else MOUSE_GMTS
    prefix = gmts.get(collection_code)
    if not prefix:
        return []
    filename = f"{prefix}.v{version}.{species}.symbols.gmt"
    url = f"{BROAD_BASE}/{version}.{species}/{filename}"
    try:
        resp = requests.get(url, timeout=60)
        resp.raise_for_status()
    except requests.RequestException as e:
        # Best effort: one unavailable collection should not abort the run.
        print(f" Warning: {url}: {e}")
        return []
    species_label = "human" if species == "Hs" else "mouse"
    return _parse_gmt(resp.text, collection_code, species_label)
|
|
|
|
def download_all_gmts(output_dir="data/raw"):
    """Download every configured human and mouse GMT collection.

    The combined records are also dumped to
    ``<output_dir>/all_gmt_genesets.json``.

    Args:
        output_dir: Directory for the JSON dump; created if missing.

    Returns:
        List of gene-set dicts from ``download_gmt``, human collections first.
    """
    os.makedirs(output_dir, exist_ok=True)
    plan = (
        ("Downloading human gene sets...", HUMAN_GMTS, "Hs"),
        ("Downloading mouse gene sets...", MOUSE_GMTS, "Mm"),
    )
    collected = []
    for banner, table, species_code in plan:
        print(banner)
        for code in table:
            batch = download_gmt(code, species_code)
            collected.extend(batch)
            print(f" {code}: {len(batch)}")
    dump_path = os.path.join(output_dir, "all_gmt_genesets.json")
    with open(dump_path, "w") as fh:
        json.dump(collected, fh)
    print(f"Total: {len(collected)}")
    return collected
|
|
|
|
# Captures the contents of the first <td> that follows the "Brief description"
# label on a gene-set card page.
_BRIEF_DESCRIPTION_RE = re.compile(
    r"Brief\s+description.*?<td[^>]*>(.*?)</td>", re.DOTALL | re.IGNORECASE
)


def fetch_description_html(name, species="human"):
    """Scrape the "Brief description" field from an MSigDB gene-set card page.

    Args:
        name: Gene-set name as it appears in the GMT file.
        species: URL path segment, e.g. "human" or "mouse".

    Returns:
        The description with HTML tags removed, entities unescaped and
        surrounding whitespace stripped. Returns "" if the request fails,
        the field is not found, or the value is "NA"/"N/A" (any case).
    """
    url = f"https://www.gsea-msigdb.org/gsea/msigdb/{species}/geneset/{name}"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
    except requests.RequestException:
        # Best effort: an unreachable page simply yields no description.
        return ""
    match = _BRIEF_DESCRIPTION_RE.search(resp.text)
    if not match:
        return ""
    desc = html.unescape(re.sub(r"<[^>]+>", "", match.group(1))).strip()
    if desc.lower() in ("na", "n/a"):
        return ""
    return desc
|
|
|
|
def fetch_descriptions_batch(gene_sets, max_workers=10, cache_path="data/raw/descriptions_cache.json"):
    """Fetch MSigDB descriptions for many gene sets, backed by a JSON cache.

    Names already in the cache file are not re-fetched. New results are
    checkpointed every 500 completions and once more at the end.

    Args:
        gene_sets: Iterable of dicts with at least "name" and "species".
        max_workers: Number of concurrent fetch threads.
        cache_path: JSON file mapping gene-set name -> description. Its
            parent directory is created if missing.

    Returns:
        The updated name -> description dict (values may be "").
    """

    def _save_cache(obj):
        # Write a sibling temp file and swap it in with os.replace, so an
        # interrupted run never leaves a truncated, unparseable cache.
        tmp_path = f"{cache_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as fp:
            json.dump(obj, fp)
        os.replace(tmp_path, cache_path)

    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path, encoding="utf-8") as fp:
            cache = json.load(fp)
    # Callers other than the main pipeline may not have created the dir yet.
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)

    # The cache is keyed by name only, so fetch each name once; the species
    # of the first record carrying that name is the one used.
    pending = {}
    for gs in gene_sets:
        name = gs["name"]
        if name not in cache and name not in pending:
            pending[name] = gs["species"]
    to_fetch = list(pending.items())
    print(f"Need to fetch {len(to_fetch)} descriptions ({len(cache)} cached)")

    # NOTE(review): fetch_description_html returns "" both for "no description"
    # and for request failures, so failures are cached and not retried later.
    if to_fetch:
        fetched = 0
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
            futures = {ex.submit(fetch_description_html, n, s): n for n, s in to_fetch}
            for fut in concurrent.futures.as_completed(futures):
                cache[futures[fut]] = fut.result()
                fetched += 1
                # Periodic checkpoint so a crash loses at most ~500 results.
                if fetched % 500 == 0:
                    _save_cache(cache)
                    print(f" {fetched}/{len(to_fetch)}")
    _save_cache(cache)
    return cache
|
|
|
|
def build_pairs(gene_sets, descriptions, min_genes=5, max_genes=2000):
    """Turn gene sets into (text, gene list) training pairs.

    The text joins with newlines: a "[Collection: ...] [Species: ...]"
    header, the set name with underscores replaced by spaces, and, when
    non-empty after cleaning, the description. Cleaning means HTML tags are
    removed, entities unescaped and runs of whitespace collapsed.

    Args:
        gene_sets: Dicts with "name", "genes", "collection" and "species".
        descriptions: Mapping of gene-set name -> description (may be "").
        min_genes: Minimum gene count (inclusive); smaller sets are dropped.
        max_genes: Maximum gene count (inclusive); larger sets are dropped.

    Returns:
        List of dicts with keys "id", "text", "genes", "n_genes",
        "collection", "species" and "has_description" (True only if a
        non-empty cleaned description was appended).
    """
    pairs = []
    for gs in gene_sets:
        genes = gs["genes"]
        if not min_genes <= len(genes) <= max_genes:
            continue
        raw_desc = descriptions.get(gs["name"], "") or ""
        # Tags become spaces; collapse those and any multi-line HTML spacing.
        desc = " ".join(html.unescape(re.sub(r"<[^>]+>", " ", raw_desc)).split())
        parts = [
            f"[Collection: {gs['collection']}] [Species: {gs['species']}]",
            gs["name"].replace("_", " "),
        ]
        if desc:
            parts.append(desc)
        text = "\n".join(parts)
        # Drop degenerate records with almost no text.
        if len(text) < 30:
            continue
        pairs.append({
            "id": gs["name"],
            "text": text,
            "genes": genes,
            "n_genes": len(genes),
            "collection": gs["collection"],
            "species": gs["species"],
            "has_description": bool(desc),
        })
    return pairs
|
|
|
|
def split_and_save(pairs, output_dir="data/processed"):
    """Split pairs by MSigDB collection and write one JSONL file per split.

    Whole collections are assigned to a split:
    - train: C1, C2, C5, C8, M1, M2, M5, M8, plus any unrecognised code;
    - val: C3, C4, M3;
    - test: H, C6, C7, MH.

    Args:
        pairs: Dicts with at least a "collection" key (e.g. from build_pairs).
        output_dir: Directory for train.jsonl, val.jsonl and test.jsonl;
            created if missing. Each file holds one JSON object per line.

    Returns:
        Dict mapping "train", "val" and "test" to their lists of pairs.
    """
    os.makedirs(output_dir, exist_ok=True)
    split_by_collection = {
        **dict.fromkeys(("C1", "C2", "C5", "C8", "M1", "M2", "M5", "M8"), "train"),
        **dict.fromkeys(("C3", "C4", "M3"), "val"),
        **dict.fromkeys(("H", "C6", "C7", "MH"), "test"),
    }
    splits = {"train": [], "val": [], "test": []}
    for p in pairs:
        splits[split_by_collection.get(p["collection"], "train")].append(p)
    for name, data in splits.items():
        path = os.path.join(output_dir, f"{name}.jsonl")
        with open(path, "w", encoding="utf-8") as f:
            # Real newline terminator: one JSON record per line (JSONL).
            f.writelines(json.dumps(r) + "\n" for r in data)
        print(f"{name}: {len(data)} pairs -> {path}")
    return splits
|
|
|
|
def main():
    """Run the pipeline: download GMTs, fetch descriptions, build pairs, split and save."""
    all_gs = download_all_gmts()
    descs = fetch_descriptions_batch(all_gs)
    pairs = build_pairs(all_gs, descs)
    print(f"\nTotal pairs: {len(pairs)}")
    split_and_save(pairs)


if __name__ == "__main__":
    main()
|
|