import os
import json
import time
import socket
import threading
import requests
import pyarrow.parquet as pq
import gc
from pathlib import Path
from huggingface_hub import HfApi
# ── Config ───────────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")
RAW_DIR = "/data/raw"
STATE_FILE = "/data/state.json"
WORKER_TIMEOUT = 700     # seconds before a claimed shard is considered stale and reclaimed
MAX_BUFFERED = 999999    # cap on downloaded-but-unclaimed shards (effectively unlimited here)
os.makedirs(RAW_DIR, exist_ok=True)
api = HfApi(token=HF_TOKEN)
AUTH_HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
# ── Sources ──────────────────────────────────────────────────────────────────
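# Two source types are supported:
#   hf_list  - list a dataset repo's files, keep those under `prefix` that end
#              in .parquet, then slice a skip/take window out of the sorted list.
#   url_list - an explicit list of file URLs (here, pre-sharded JSONL files,
#              two shards per language).
# `text_col` names the column/field that holds the document text.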
SOURCES = [
    {
        "name"    : "fineweb",
        "type"    : "hf_list",
        "repo"    : "HuggingFaceFW/fineweb-edu",
        "prefix"  : "data/CC-MAIN-2025-26",
        "skip"    : 5,
        "take"    : 10,
        "text_col": "text",
    },
    {
        "name"    : "wikipedia",
        "type"    : "hf_list",
        "repo"    : "wikimedia/wikipedia",
        "prefix"  : "20231101.en/train-",
        "skip"    : 2,
        "take"    : 18,
        "text_col": "text",
    },
    {
        "name"    : "openwebmath",
        "type"    : "hf_list",
        "repo"    : "open-web-math/open-web-math",
        "prefix"  : "data/train-",
        "skip"    : 0,
        "take"    : 6,
        "text_col": "text",
    },
    {
        "name"    : "code",
        "type"    : "url_list",
        "text_col": "text",
        "fmt"     : "jsonl",
        "urls"    : [
            f"https://huggingface.co/buckets/Neon-tech/Dataset-arranger/resolve/by-language/{lang}/shard_{str(i).zfill(6)}.jsonl?download=true"
            for lang in ["C", "C++", "Java", "Go", "Rust", "Ruby", "PHP", "SQL", "C#", "Scala", "Lua", "Perl", "CSS"]
            for i in range(2)
        ],
    },
]
# ── Keep-alive ───────────────────────────────────────────────────────────────
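# Hugging Face Spaces routes HTTP traffic to port 7860 by default; this bare
# socket server answers every request with a fixed 200 OK so the Space stays
# "Running" while the real work happens in background threads.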
def serve():
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind(("0.0.0.0", 7860))
    s.listen(5)
    print("✓ Listening on port 7860")
    while True:
        conn, _ = s.accept()
        conn.sendall(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
        conn.close()
# ── State ────────────────────────────────────────────────────────────────────
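# state.json layout:
#   "queue"  - list of {url, source, text_col, fmt} entries not yet downloaded
#   "shards" - map of shard filename -> {status, url, source, worker,
#              claimed_at, error}; status is "pending" (downloaded, awaiting a
#              worker), "claimed", or "done"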
def load_state():
    if os.path.exists(STATE_FILE):
        with open(STATE_FILE) as f:
            state = json.load(f)
        shards = state["shards"]
        queue = state.get("queue", [])
        done = sum(1 for v in shards.values() if v["status"] == "done")
        claimed = sum(1 for v in shards.values() if v["status"] == "claimed")
        pending = sum(1 for v in shards.values() if v["status"] == "pending")
        print(f"Resuming → {done} done / {claimed} claimed / {pending} buffered / {len(queue)} queued")
    else:
        state = {"shards": {}, "queue": []}
        print("Starting fresh")
    return state
def save_state(state):
    # Write to a temp file, then atomically swap it in, so a crash mid-write
    # can never leave a truncated state.json behind.
    tmp = STATE_FILE + ".tmp"
    with open(tmp, "w") as f:
        json.dump(state, f, indent=2)
    os.replace(tmp, STATE_FILE)
# ── Discover ─────────────────────────────────────────────────────────────────
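# Discovery is idempotent: URLs already tracked in `shards` or `queue` are
# skipped, so re-running it at startup only appends genuinely new files.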
def discover_all(state):
    known_urls = {v["url"] for v in state["shards"].values()} | {e["url"] for e in state.get("queue", [])}
    new_count = 0
    for src in SOURCES:
        name = src["name"]
        print(f"\nDiscovering: {name}")
        if src["type"] == "hf_list":
            all_files = sorted([
                f for f in api.list_repo_files(src["repo"], repo_type="dataset")
                if f.startswith(src["prefix"]) and f.endswith(".parquet")
            ])
            selected = all_files[src["skip"]: src["skip"] + src["take"]]
            base_url = f"https://huggingface.co/datasets/{src['repo']}/resolve/main/"
            urls = [base_url + f for f in selected]
            fmt = "parquet"
        else:
            urls = src["urls"]
            fmt = src.get("fmt", "parquet")
        added = 0
        for url in urls:
            if url not in known_urls:
                state["queue"].append({
                    "url"     : url,
                    "source"  : name,
                    "text_col": src["text_col"],
                    "fmt"     : fmt,
                })
                known_urls.add(url)
                new_count += 1
                added += 1
        print(f" {name}: {len(urls)} files | {added} new added to queue")
    save_state(state)
    print(f"\nTotal queued: {len(state['queue'])} | In state: {len(state['shards'])}")
# ── Reclaim stale ────────────────────────────────────────────────────────────
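# A worker that claims a shard and then dies would block it forever; any claim
# older than WORKER_TIMEOUT seconds is returned to "pending".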
def reclaim_stale(state):
    now = time.time()
    reclaimed = 0
    for name, info in state["shards"].items():
        if info["status"] == "claimed" and info.get("claimed_at"):
            if now - info["claimed_at"] > WORKER_TIMEOUT:
                print(f" ⚠ Reclaiming: {name}")
                info["status"] = "pending"
                info["worker"] = None
                info["claimed_at"] = None
                reclaimed += 1
    if reclaimed:
        save_state(state)
# ── Parquet → JSONL ──────────────────────────────────────────────────────────
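# Conversion streams record batches rather than loading the whole table, so
# peak memory stays around one 1,000-row batch regardless of shard size.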
def parquet_to_jsonl(parquet_path, jsonl_path, text_col):
    """Stream parquet batch by batch → write one JSON line per doc. No full load."""
    pf = pq.ParquetFile(parquet_path)
    n_written = 0
    with open(jsonl_path, "w", encoding="utf-8") as out:
        for batch in pf.iter_batches(batch_size=1_000, columns=[text_col]):
            texts = batch.column(text_col).to_pylist()
            for text in texts:
                if text and isinstance(text, str) and text.strip():
                    out.write(json.dumps({"text": text.strip()}, ensure_ascii=False) + "\n")
                    n_written += 1
            del texts
            gc.collect()
    return n_written
# ── Download loop ────────────────────────────────────────────────────────────
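# Producer loop: re-read state.json each pass (workers update it out of band),
# reclaim stale claims, then download and convert the head of the queue. A
# failed download or conversion leaves the entry at queue[0] and retries after
# a short sleep.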
def download_loop(state):
    while True:
        # Pick up changes workers made to state.json since the last pass.
        try:
            with open(STATE_FILE) as f:
                fresh = json.load(f)
            state["shards"] = fresh["shards"]
            state["queue"] = fresh.get("queue", [])
        except Exception:
            pass
        reclaim_stale(state)
        # Backpressure: pause downloading while too many shards await workers.
        buffered = sum(1 for v in state["shards"].values() if v["status"] == "pending")
        if buffered >= MAX_BUFFERED:
            time.sleep(30)
            continue
        if not state["queue"]:
            done = sum(1 for v in state["shards"].values() if v["status"] == "done")
            total = len(state["shards"])
            if done == total and total > 0:
                print("✓ All shards complete!")
                break
            print(" Queue empty → sleeping...")
            time.sleep(60)
            continue
        entry = state["queue"][0]
        url = entry["url"]
        source = entry["source"]
        text_col = entry["text_col"]
        fmt = entry.get("fmt", "parquet")
        # Second-to-last path segment: the language dir for url_list sources,
        # the parent data dir for hf_list sources.
        lang = url.split("?")[0].split("/")[-2]
        base_name = url.split("?")[0].split("/")[-1].replace(".parquet", "").replace(".jsonl", "")
        shard_name = f"{source}__{base_name}_{lang}.jsonl"
        jsonl_path = Path(RAW_DIR) / shard_name
        tmp_path = Path(RAW_DIR) / f"{shard_name}.tmp"
        print(f" Downloading: {source} | {base_name}")
        try:
            resp = requests.get(url, headers=AUTH_HEADERS, timeout=300, stream=True)
            resp.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8 * 1024 * 1024):
                    f.write(chunk)
        except Exception as e:
            print(f" ✗ Download failed: {e} → retrying in 30s")
            tmp_path.unlink(missing_ok=True)
            time.sleep(30)
            continue
        if fmt == "parquet":
            print(f" Converting → jsonl: {shard_name}")
            try:
                n = parquet_to_jsonl(tmp_path, jsonl_path, text_col)
                tmp_path.unlink(missing_ok=True)
                print(f" ✓ {n:,} docs")
            except Exception as e:
                print(f" ✗ Convert failed: {e}")
                tmp_path.unlink(missing_ok=True)
                jsonl_path.unlink(missing_ok=True)
                time.sleep(30)
                continue
        else:
            tmp_path.rename(jsonl_path)
        state["queue"].pop(0)
        state["shards"][shard_name] = {
            "status"    : "pending",
            "url"       : url,
            "source"    : source,
            "worker"    : None,
            "claimed_at": None,
            "error"     : None,
        }
        save_state(state)
        print(f" ✓ Ready: {shard_name}")
        time.sleep(3)
# ── Monitor ──────────────────────────────────────────────────────────────────
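# Read-only progress report every two minutes, taken straight from state.json
# so it never races the download thread's in-memory copy.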
def monitor_loop():
    while True:
        time.sleep(120)
        try:
            with open(STATE_FILE) as f:
                s = json.load(f)
            shards = s["shards"]
            queue = s.get("queue", [])
            done = sum(1 for v in shards.values() if v["status"] == "done")
            claimed = sum(1 for v in shards.values() if v["status"] == "claimed")
            pending = sum(1 for v in shards.values() if v["status"] == "pending")
            total = len(shards) + len(queue)
            pct = (done / total * 100) if total else 0
            src_done = {}
            for v in shards.values():
                src = v.get("source", "?")
                if v["status"] == "done":
                    src_done[src] = src_done.get(src, 0) + 1
            print(f"[MONITOR] {done}/{total} ({pct:.1f}%) | {claimed} active | {pending} buffered | {len(queue)} queued")
            for src, cnt in sorted(src_done.items()):
                print(f" {src}: {cnt} done")
        except Exception:
            pass
# ── Entry point ──────────────────────────────────────────────────────────────
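# All work runs in daemon threads; the main thread just idles so the process
# (and with it the Space) stays alive.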
if __name__ == "__main__":
    threading.Thread(target=serve, daemon=True).start()
    state = load_state()
    discover_all(state)
    threading.Thread(target=monitor_loop, daemon=True).start()
    threading.Thread(target=download_loop, args=(state,), daemon=True).start()
    while True:
        time.sleep(60)