"""
MSigDB Data Processing Pipeline for Contrastive Pretraining.

Strategy:
1. Download full GMT files from Broad data server (no auth needed)
2. Fetch brief descriptions from MSigDB HTML card pages (no auth needed)
3. Build text-gene paired dataset

Usage:
    python data_processing.py
"""

import concurrent.futures
import html
import json
import os
import re

import requests

BROAD_BASE = "https://data.broadinstitute.org/gsea-msigdb/msigdb/release"
VERSION = "2024.1"

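# Collection code -> GMT filename prefix; files follow the pattern
# <prefix>.v<VERSION>.<Hs|Mm>.symbols.gmt on the Broad release server.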
HUMAN_GMTS = {"H": "h.all", "C1": "c1.all", "C2": "c2.all", "C3": "c3.all",
              "C4": "c4.all", "C5": "c5.all", "C6": "c6.all", "C7": "c7.all", "C8": "c8.all"}
MOUSE_GMTS = {"MH": "mh.all", "M1": "m1.all", "M2": "m2.all", "M3": "m3.all",
              "M5": "m5.all", "M8": "m8.all"}


def download_gmt(collection_code, species="Hs", version=VERSION):
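    """Download and parse one collection's GMT file from the Broad server.

    GMT is a tab-separated format: set name, then a description/URL field,
    then one gene symbol per remaining column, e.g.
        HALLMARK_APOPTOSIS<TAB>https://...<TAB>CASP3<TAB>CASP7<TAB>...
    Returns a list of dicts; empty on unknown collection or download error.
    """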
    gmts = HUMAN_GMTS if species == "Hs" else MOUSE_GMTS
    prefix = gmts.get(collection_code)
    if not prefix:
        return []
    filename = f"{prefix}.v{version}.{species}.symbols.gmt"
    url = f"{BROAD_BASE}/{version}.{species}/{filename}"
    try:
        resp = requests.get(url, timeout=60)
        resp.raise_for_status()
    except Exception as e:
        print(f"  Warning: {url}: {e}")
        return []
    gene_sets = []
    for line in resp.text.strip().split("\n"):
        parts = line.split("\t")
        if len(parts) < 3:
            continue
        gene_sets.append({
            "name": parts[0].strip(), "url": parts[1].strip(),
            "genes": [g.strip() for g in parts[2:] if g.strip()],
            "collection": collection_code,
            "species": "human" if species == "Hs" else "mouse",
        })
    return gene_sets
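
# Example usage (a sketch; requires network access - "H" is the hallmark collection):
#   hallmark = download_gmt("H", "Hs")
#   print(hallmark[0]["name"], len(hallmark[0]["genes"]))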


def download_all_gmts(output_dir="data/raw"):
    os.makedirs(output_dir, exist_ok=True)
    all_gs = []
    print("Downloading human gene sets...")
    for code in HUMAN_GMTS:
        gs = download_gmt(code, "Hs")
        all_gs.extend(gs)
        print(f"  {code}: {len(gs)}")
    print("Downloading mouse gene sets...")
    for code in MOUSE_GMTS:
        gs = download_gmt(code, "Mm")
        all_gs.extend(gs)
        print(f"  {code}: {len(gs)}")
    with open(os.path.join(output_dir, "all_gmt_genesets.json"), "w") as f:
        json.dump(all_gs, f)
    print(f"Total: {len(all_gs)}")
    return all_gs


def fetch_description_html(name, species="human"):
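    """Scrape the 'Brief description' field from a set's MSigDB card page.

    Returns the plain-text description, or "" if the page is unreachable,
    the field is missing, or it only contains "NA"/"N/A".
    """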
    url = f"https://www.gsea-msigdb.org/gsea/msigdb/{species}/geneset/{name}"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
        match = re.findall(r'Brief\s+description.*?<td[^>]*>(.*?)</td>', resp.text, re.DOTALL | re.IGNORECASE)
        if match:
            desc = re.sub(r'<[^>]+>', '', match[0]).strip()
            desc = html.unescape(desc)
            if desc and desc.lower() not in ["na", "n/a"]:
                return desc
    except Exception:
        pass
    return ""


def fetch_descriptions_batch(gene_sets, max_workers=10, cache_path="data/raw/descriptions_cache.json"):
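    """Fetch descriptions for all gene sets concurrently, with a JSON cache.

    Already-cached names are skipped, and the cache is checkpointed to disk
    every 500 fetches so an interrupted run can resume where it left off.
    """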
    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cache = json.load(f)
    to_fetch = [(gs["name"], gs["species"]) for gs in gene_sets if gs["name"] not in cache]
    print(f"Need to fetch {len(to_fetch)} descriptions ({len(cache)} cached)")
    if to_fetch:
        fetched = 0
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
            futures = {ex.submit(fetch_description_html, n, s): (n, s) for n, s in to_fetch}
            for f in concurrent.futures.as_completed(futures):
                name, _ = futures[f]
                cache[name] = f.result()
                fetched += 1
                if fetched % 500 == 0:
                    with open(cache_path, "w") as fp:
                        json.dump(cache, fp)
                    print(f"  {fetched}/{len(to_fetch)}")
        with open(cache_path, "w") as f:
            json.dump(cache, f)
    return cache


def build_pairs(gene_sets, descriptions, min_genes=5, max_genes=2000):
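    """Assemble (text, genes) training pairs.

    Sets outside [min_genes, max_genes] are dropped, as are pairs whose text
    is shorter than 30 characters. The text side combines a collection/species
    tag, the set name with underscores replaced by spaces, and (when
    available) the HTML-stripped brief description.
    """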
    pairs = []
    for gs in gene_sets:
        genes = gs["genes"]
        if len(genes) < min_genes or len(genes) > max_genes:
            continue
        desc = descriptions.get(gs["name"], "")
        parts = [f"[Collection: {gs['collection']}] [Species: {gs['species']}]",
                 gs["name"].replace("_", " ")]
        if desc:
            parts.append(html.unescape(re.sub(r'<[^>]+>', ' ', desc)).strip())
        text = "\n".join(parts)
        if len(text) < 30:
            continue
        pairs.append({"id": gs["name"], "text": text, "genes": genes,
                      "n_genes": len(genes), "collection": gs["collection"],
                      "species": gs["species"], "has_description": bool(desc)})
    return pairs


def split_and_save(pairs, output_dir="data/processed"):
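    """Write collection-level train/val/test splits as JSONL.

    Splitting by collection rather than at random keeps whole collections out
    of training: C1/C2/C5/C8 (+ mouse M1/M2/M5/M8) train, C3/C4/M3 validate,
    H/C6/C7/MH test. Unknown collection codes fall back to the train split.
    """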
    os.makedirs(output_dir, exist_ok=True)
    train_cols = {"C2", "C5", "C8", "C1", "M2", "M5", "M8", "M1"}
    val_cols = {"C3", "C4", "M3"}
    test_cols = {"H", "C6", "C7", "MH"}
    splits = {"train": [], "val": [], "test": []}
    for p in pairs:
        c = p["collection"]
        if c in train_cols:
            splits["train"].append(p)
        elif c in val_cols:
            splits["val"].append(p)
        elif c in test_cols:
            splits["test"].append(p)
        else:
            splits["train"].append(p)
    for name, data in splits.items():
        path = os.path.join(output_dir, f"{name}.jsonl")
        with open(path, "w") as f:
            for r in data:
                f.write(json.dumps(r) + "\n")
        print(f"{name}: {len(data)} pairs -> {path}")
    return splits


if __name__ == "__main__":
    all_gs = download_all_gmts()
    descs = fetch_descriptions_batch(all_gs)
    pairs = build_pairs(all_gs, descs)
    print(f"\\nTotal pairs: {len(pairs)}")
    split_and_save(pairs)
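
# A minimal sketch of consuming the output downstream (hypothetical code; the
# record fields match what build_pairs() produces above):
#
#   import json
#   with open("data/processed/train.jsonl") as f:
#       records = [json.loads(line) for line in f]
#   texts = [r["text"] for r in records]
#   gene_lists = [r["genes"] for r in records]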