"""Import a historical Hugging Face dataset checkpoint into a clean snapshot. This is mainly for legacy datasets that kept their richest data under `_checkpoints//...` instead of promoting the latest full snapshot to the dataset root. The importer: 1. selects a checkpoint from a source HF dataset repo 2. downloads the checkpoint parquet files 3. rewrites small tables to the current local schema when needed 4. regenerates derived artifacts (`links`, `issue_comments`, `pr_comments`) 5. writes a normal local snapshot directory that `analyze` / `pr-scope` can use 6. optionally republishes that clean snapshot """ from __future__ import annotations import re import shutil from collections import defaultdict from collections.abc import Mapping from datetime import UTC, datetime from pathlib import Path from typing import Any from huggingface_hub import HfApi, hf_hub_download from slop_farmer.app.publish_dataset_snapshot import publish_dataset_snapshot from slop_farmer.config import CheckpointImportOptions from slop_farmer.data.dataset_card import build_hf_dataset_card from slop_farmer.data.links import build_pr_duplicate_candidate_rows, build_text_link_rows from slop_farmer.data.parquet_io import ( SCHEMAS, read_json, read_parquet_rows, write_json, write_parquet, write_text, ) CHECKPOINT_PATH_PATTERN = re.compile( r"^(?P_?checkpoints)/(?P[^/]+)/(?P[^/]+)$" ) REQUIRED_CHECKPOINT_FILES = ( "issues.parquet", "pull_requests.parquet", "comments.parquet", "reviews.parquet", "review_comments.parquet", "pr_files.parquet", "pr_diffs.parquet", "events.parquet", ) SMALL_TABLES = ( "issues", "pull_requests", "comments", "reviews", "review_comments", "events", ) COPIED_TABLES = ( "pr_files", "pr_diffs", ) __all__ = [ "import_hf_checkpoint", "materialize_checkpoint_snapshot", ] def import_hf_checkpoint(options: CheckpointImportOptions) -> Path: api = HfApi() siblings = [ sibling.rfilename for sibling in api.dataset_info(repo_id=options.source_repo_id).siblings or [] ] checkpoint_root, checkpoint_id = _select_checkpoint( siblings, checkpoint_id=options.checkpoint_id, checkpoint_root=options.checkpoint_root, ) checkpoint_files = { filename: Path( hf_hub_download( repo_id=options.source_repo_id, repo_type="dataset", filename=f"{checkpoint_root}/{checkpoint_id}/{filename}", ) ) for filename in REQUIRED_CHECKPOINT_FILES } source_manifest = _maybe_download_json(options.source_repo_id, "manifest.json") source_progress = _maybe_download_json( options.source_repo_id, f"{checkpoint_root}/{checkpoint_id}/progress.json" ) snapshot_dir = materialize_checkpoint_snapshot( checkpoint_files, output_root=options.output_dir, source_repo_id=options.source_repo_id, checkpoint_root=checkpoint_root, checkpoint_id=checkpoint_id, source_manifest=source_manifest, source_progress=source_progress, force=options.force, ) if options.publish_repo_id: publish_dataset_snapshot( snapshot_dir, options.publish_repo_id, private=options.private_hf_repo ) return snapshot_dir def materialize_checkpoint_snapshot( checkpoint_files: Mapping[str, Path], *, output_root: Path, source_repo_id: str, checkpoint_root: str, checkpoint_id: str, source_manifest: Mapping[str, Any] | None = None, source_progress: Mapping[str, Any] | None = None, force: bool = False, ) -> Path: source_manifest = dict(source_manifest or {}) source_progress = dict(source_progress or {}) output_root = output_root.resolve() snapshot_dir = output_root / "snapshots" / _snapshot_dir_name(source_repo_id, checkpoint_id) if snapshot_dir.exists(): if not force: raise 
FileExistsError(f"Snapshot already exists: {snapshot_dir}") shutil.rmtree(snapshot_dir) snapshot_dir.mkdir(parents=True, exist_ok=True) rows_by_table: dict[str, list[dict[str, Any]]] = {} for table_name in SMALL_TABLES: rows = read_parquet_rows(checkpoint_files[f"{table_name}.parquet"]) rows_by_table[table_name] = _project_rows(rows, table_name) write_parquet(rows_by_table[table_name], snapshot_dir / f"{table_name}.parquet", table_name) for table_name in COPIED_TABLES: shutil.copy2( checkpoint_files[f"{table_name}.parquet"], snapshot_dir / f"{table_name}.parquet" ) repo_slug = _resolve_repo(rows_by_table, source_manifest, source_progress) extracted_at = _resolve_extracted_at(rows_by_table, source_manifest) link_rows = _derived_link_rows( repo_slug=repo_slug, snapshot_id=checkpoint_id, extracted_at=extracted_at, issues=rows_by_table["issues"], pull_requests=rows_by_table["pull_requests"], comments=rows_by_table["comments"], reviews=rows_by_table["reviews"], review_comments=rows_by_table["review_comments"], events=rows_by_table["events"], ) write_parquet(link_rows, snapshot_dir / "links.parquet", "links") issue_comment_rows, pr_comment_rows = _viewer_comment_rows( rows_by_table["comments"], rows_by_table["pull_requests"] ) write_parquet(issue_comment_rows, snapshot_dir / "issue_comments.parquet", "comments") write_parquet(pr_comment_rows, snapshot_dir / "pr_comments.parquet", "comments") manifest = { "repo": repo_slug, "snapshot_id": checkpoint_id, "extracted_at": extracted_at, "imported_at": _iso_now(), "source_type": "hf_checkpoint_import", "source_hf_repo_id": source_repo_id, "source_checkpoint_root": checkpoint_root, "source_checkpoint_id": checkpoint_id, "source_manifest": source_manifest, "source_progress": source_progress, "counts": { "issues": len(rows_by_table["issues"]), "pull_requests": len(rows_by_table["pull_requests"]), "comments": len(rows_by_table["comments"]), "reviews": len(rows_by_table["reviews"]), "review_comments": len(rows_by_table["review_comments"]), "pr_files": _parquet_row_count(snapshot_dir / "pr_files.parquet"), "pr_diffs": _parquet_row_count(snapshot_dir / "pr_diffs.parquet"), "timeline_events": len(rows_by_table["events"]), "links": len(link_rows), }, "notes": ( "Imported from a remote HF dataset checkpoint and regenerated locally derived link/comment views." ), } write_json(manifest, snapshot_dir / "manifest.json") write_text( _dataset_card(repo_slug, checkpoint_id, source_repo_id, checkpoint_root), snapshot_dir / "README.md", ) return snapshot_dir def _select_checkpoint( sibling_paths: list[str], *, checkpoint_id: str | None, checkpoint_root: str | None, ) -> tuple[str, str]: candidates: defaultdict[tuple[str, str], set[str]] = defaultdict(set) for path in sibling_paths: match = CHECKPOINT_PATH_PATTERN.match(path) if match is None: continue root = match.group("root") snapshot_id = match.group("snapshot_id") filename = match.group("filename") candidates[(root, snapshot_id)].add(filename) viable = [ (root, snapshot_id) for (root, snapshot_id), filenames in candidates.items() if all(filename in filenames for filename in REQUIRED_CHECKPOINT_FILES) ] if checkpoint_id is not None: requested = [ candidate for candidate in viable if candidate[1] == checkpoint_id and (checkpoint_root is None or candidate[0] == checkpoint_root) ] if not requested: raise ValueError( f"Checkpoint {checkpoint_root or '*'}:{checkpoint_id} not found with required files." 


def _checkpoint_sort_key(candidate: tuple[str, str]) -> tuple[str, int]:
    root, snapshot_id = candidate
    root_priority = 1 if root == "_checkpoints" else 0
    return snapshot_id, root_priority


def _maybe_download_json(repo_id: str, filename: str) -> dict[str, Any]:
    try:
        path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=filename)
    except Exception:
        return {}
    return read_json(Path(path))


def _project_rows(rows: list[dict[str, Any]], table_name: str) -> list[dict[str, Any]]:
    field_names = [field.name for field in SCHEMAS[table_name]]
    return [{field_name: row.get(field_name) for field_name in field_names} for row in rows]


def _resolve_repo(
    rows_by_table: Mapping[str, list[dict[str, Any]]],
    source_manifest: Mapping[str, Any],
    source_progress: Mapping[str, Any],
) -> str:
    for table_name in ("issues", "pull_requests", "comments"):
        rows = rows_by_table.get(table_name) or []
        if rows and rows[0].get("repo"):
            return str(rows[0]["repo"])
    return str(source_manifest.get("repo") or source_progress.get("repo") or "")


def _resolve_extracted_at(
    rows_by_table: Mapping[str, list[dict[str, Any]]],
    source_manifest: Mapping[str, Any],
) -> str:
    for table_name in (
        "issues",
        "pull_requests",
        "comments",
        "reviews",
        "review_comments",
        "events",
    ):
        rows = rows_by_table.get(table_name) or []
        if rows and rows[0].get("extracted_at"):
            return str(rows[0]["extracted_at"])
    return str(source_manifest.get("extracted_at") or _iso_now())


def _derived_link_rows(
    *,
    repo_slug: str,
    snapshot_id: str,
    extracted_at: str,
    issues: list[dict[str, Any]],
    pull_requests: list[dict[str, Any]],
    comments: list[dict[str, Any]],
    reviews: list[dict[str, Any]],
    review_comments: list[dict[str, Any]],
    events: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Rebuild the `links` table from text references and timeline events."""
    owner, repo_name = repo_slug.split("/", 1)
    link_rows: list[dict[str, Any]] = []
    for issue_row in issues:
        link_rows.extend(
            build_text_link_rows(
                repo=repo_slug,
                owner=owner,
                repo_name=repo_name,
                source_type="issue",
                source_number=int(issue_row["number"]),
                source_id=issue_row.get("github_id"),
                body=issue_row.get("body"),
                snapshot_id=snapshot_id,
                extracted_at=extracted_at,
            )
        )
    for pr_row in pull_requests:
        link_rows.extend(
            build_text_link_rows(
                repo=repo_slug,
                owner=owner,
                repo_name=repo_name,
                source_type="pull_request",
                source_number=int(pr_row["number"]),
                source_id=pr_row.get("github_id"),
                body=pr_row.get("body"),
                snapshot_id=snapshot_id,
                extracted_at=extracted_at,
            )
        )
    for comment_row in comments:
        parent_number = comment_row.get("parent_number")
        if parent_number is None:
            continue
        link_rows.extend(
            build_text_link_rows(
                repo=repo_slug,
                owner=owner,
                repo_name=repo_name,
                source_type="comment",
                source_number=int(parent_number),
                source_id=comment_row.get("github_id"),
                body=comment_row.get("body"),
                snapshot_id=snapshot_id,
                extracted_at=extracted_at,
            )
        )
    for review_row in reviews:
        link_rows.extend(
            build_text_link_rows(
                repo=repo_slug,
                owner=owner,
                repo_name=repo_name,
                source_type="review",
                source_number=int(review_row["pull_request_number"]),
                source_id=review_row.get("github_id"),
                body=review_row.get("body"),
                snapshot_id=snapshot_id,
                extracted_at=extracted_at,
            )
        )
    for review_comment_row in review_comments:
        link_rows.extend(
            build_text_link_rows(
                repo=repo_slug,
                owner=owner,
                repo_name=repo_name,
                source_type="review_comment",
                source_number=int(review_comment_row["pull_request_number"]),
                source_id=review_comment_row.get("github_id"),
                body=review_comment_row.get("body"),
                snapshot_id=snapshot_id,
                extracted_at=extracted_at,
            )
        )
    link_rows.extend(
        build_pr_duplicate_candidate_rows(
            repo=repo_slug,
            pull_requests=pull_requests,
            link_rows=link_rows,
            snapshot_id=snapshot_id,
            extracted_at=extracted_at,
        )
    )
    for event in events:
        source_issue_number = event.get("source_issue_number")
        if not source_issue_number:
            continue
        link_rows.append(
            {
                "repo": repo_slug,
                "source_type": event["parent_kind"],
                "source_number": event["parent_number"],
                "source_github_id": None,
                "target_owner": owner,
                "target_repo": repo_name,
                "target_number": source_issue_number,
                "link_type": f"timeline:{event['event']}",
                "link_origin": "timeline",
                "snapshot_id": snapshot_id,
                "extracted_at": extracted_at,
            }
        )
    return sorted(
        _dedupe_link_rows(link_rows),
        key=lambda row: (
            str(row.get("source_type") or ""),
            int(row.get("source_number") or 0),
            int(row.get("source_github_id") or 0),
            str(row.get("target_owner") or ""),
            str(row.get("target_repo") or ""),
            int(row.get("target_number") or 0),
            str(row.get("link_type") or ""),
            str(row.get("link_origin") or ""),
        ),
    )
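

# The same logical link can be emitted more than once above; `_dedupe_link_rows` keeps
# the last row seen for each (repo, source, target, link_type, link_origin) combination.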


def _dedupe_link_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    deduped: dict[tuple[Any, ...], dict[str, Any]] = {}
    for row in rows:
        key = (
            row.get("repo"),
            row.get("source_type"),
            row.get("source_number"),
            row.get("source_github_id"),
            row.get("target_owner"),
            row.get("target_repo"),
            row.get("target_number"),
            row.get("link_type"),
            row.get("link_origin"),
        )
        deduped[key] = row
    return list(deduped.values())


def _viewer_comment_rows(
    comments: list[dict[str, Any]],
    pull_requests: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    pr_numbers = {int(row["number"]) for row in pull_requests if row.get("number") is not None}
    issue_comments: list[dict[str, Any]] = []
    pr_comments: list[dict[str, Any]] = []
    for row in comments:
        parent_number = row.get("parent_number")
        parent_kind = row.get("parent_kind")
        if parent_kind == "pull_request" or parent_number in pr_numbers:
            pr_comments.append(row)
        else:
            issue_comments.append(row)
    return issue_comments, pr_comments


def _dataset_card(
    repo_slug: str, snapshot_id: str, source_repo_id: str, checkpoint_root: str
) -> str:
    return build_hf_dataset_card(
        repo_slug,
        snapshot_id,
        notes=[
            f"source HF dataset: `{source_repo_id}`",
            f"source checkpoint root: `{checkpoint_root}`",
            "links were regenerated locally from text references and timeline events",
        ],
    )


def _snapshot_dir_name(source_repo_id: str, checkpoint_id: str) -> str:
    return f"hf-{source_repo_id.replace('/', '--')}-{checkpoint_id}"


def _iso_now() -> str:
    return datetime.now(tz=UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")


def _parquet_row_count(path: Path) -> int:
    import pyarrow.parquet as pq

    return pq.read_metadata(path).num_rows
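

# Resulting snapshot layout (as written by `materialize_checkpoint_snapshot`):
#
#     <output_root>/snapshots/hf-<source repo id with "/" replaced by "--">-<checkpoint_id>/
#         issues.parquet, pull_requests.parquet, comments.parquet,
#         reviews.parquet, review_comments.parquet, events.parquet    # projected to SCHEMAS
#         pr_files.parquet, pr_diffs.parquet                          # copied verbatim
#         links.parquet, issue_comments.parquet, pr_comments.parquet  # regenerated locally
#         manifest.json, README.md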