"""Import a historical Hugging Face dataset checkpoint into a clean snapshot.

This is mainly for legacy datasets that kept their richest data under
`_checkpoints/<snapshot_id>/...` instead of promoting the latest full snapshot
to the dataset root.

The importer:
1. selects a checkpoint from a source HF dataset repo
2. downloads the checkpoint parquet files
3. rewrites small tables to the current local schema when needed
4. regenerates derived artifacts (`links`, `issue_comments`, `pr_comments`)
5. writes a normal local snapshot directory that `analyze` / `pr-scope` can use
6. optionally republishes that clean snapshot
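
Example (a minimal sketch: it assumes ``CheckpointImportOptions`` accepts
the fields this module reads as keyword arguments, and all values below are
illustrative)::

    from pathlib import Path

    from slop_farmer.app.hf_checkpoint_import import import_hf_checkpoint
    from slop_farmer.config import CheckpointImportOptions

    snapshot_dir = import_hf_checkpoint(
        CheckpointImportOptions(
            source_repo_id="owner/legacy-dataset",
            checkpoint_id=None,    # None selects the newest viable checkpoint
            checkpoint_root=None,  # match "checkpoints" or "_checkpoints"
            output_dir=Path("data"),
            publish_repo_id=None,  # set a repo id to republish the snapshot
            private_hf_repo=False,
            force=False,
        )
    )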
"""

from __future__ import annotations

import re
import shutil
from collections import defaultdict
from collections.abc import Mapping
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

from huggingface_hub import HfApi, hf_hub_download

from slop_farmer.app.publish_dataset_snapshot import publish_dataset_snapshot
from slop_farmer.config import CheckpointImportOptions
from slop_farmer.data.dataset_card import build_hf_dataset_card
from slop_farmer.data.links import build_pr_duplicate_candidate_rows, build_text_link_rows
from slop_farmer.data.parquet_io import (
SCHEMAS,
read_json,
read_parquet_rows,
write_json,
write_parquet,
write_text,
)
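
# Matches checkpoint parquet paths such as
# "_checkpoints/2024-01-01T00-00-00/issues.parquet"; the root may appear with
# or without the leading underscore.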
CHECKPOINT_PATH_PATTERN = re.compile(
r"^(?P<root>_?checkpoints)/(?P<snapshot_id>[^/]+)/(?P<filename>[^/]+)$"
)

REQUIRED_CHECKPOINT_FILES = (
"issues.parquet",
"pull_requests.parquet",
"comments.parquet",
"reviews.parquet",
"review_comments.parquet",
"pr_files.parquet",
"pr_diffs.parquet",
"events.parquet",
)
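
# SMALL_TABLES are re-read and projected onto the current local schemas;
# COPIED_TABLES (the large pr_files / pr_diffs parquets) are copied verbatim.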
SMALL_TABLES = (
"issues",
"pull_requests",
"comments",
"reviews",
"review_comments",
"events",
)

COPIED_TABLES = (
"pr_files",
"pr_diffs",
)

__all__ = [
"import_hf_checkpoint",
"materialize_checkpoint_snapshot",
]


def import_hf_checkpoint(options: CheckpointImportOptions) -> Path:
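    """Download a checkpoint from the source HF dataset and import it locally."""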
api = HfApi()
siblings = [
sibling.rfilename
for sibling in api.dataset_info(repo_id=options.source_repo_id).siblings or []
]
checkpoint_root, checkpoint_id = _select_checkpoint(
siblings,
checkpoint_id=options.checkpoint_id,
checkpoint_root=options.checkpoint_root,
)
checkpoint_files = {
filename: Path(
hf_hub_download(
repo_id=options.source_repo_id,
repo_type="dataset",
filename=f"{checkpoint_root}/{checkpoint_id}/{filename}",
)
)
for filename in REQUIRED_CHECKPOINT_FILES
}
source_manifest = _maybe_download_json(options.source_repo_id, "manifest.json")
source_progress = _maybe_download_json(
options.source_repo_id, f"{checkpoint_root}/{checkpoint_id}/progress.json"
)
snapshot_dir = materialize_checkpoint_snapshot(
checkpoint_files,
output_root=options.output_dir,
source_repo_id=options.source_repo_id,
checkpoint_root=checkpoint_root,
checkpoint_id=checkpoint_id,
source_manifest=source_manifest,
source_progress=source_progress,
force=options.force,
)
if options.publish_repo_id:
publish_dataset_snapshot(
snapshot_dir, options.publish_repo_id, private=options.private_hf_repo
)
return snapshot_dir


def materialize_checkpoint_snapshot(
checkpoint_files: Mapping[str, Path],
*,
output_root: Path,
source_repo_id: str,
checkpoint_root: str,
checkpoint_id: str,
source_manifest: Mapping[str, Any] | None = None,
source_progress: Mapping[str, Any] | None = None,
force: bool = False,
) -> Path:
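    """Rewrite checkpoint tables into a local snapshot directory.

    Small tables are projected onto the current schemas, large tables are
    copied verbatim, and the derived link/comment views, manifest, and
    dataset card are regenerated alongside them.
    """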
source_manifest = dict(source_manifest or {})
source_progress = dict(source_progress or {})
output_root = output_root.resolve()
snapshot_dir = output_root / "snapshots" / _snapshot_dir_name(source_repo_id, checkpoint_id)
if snapshot_dir.exists():
if not force:
raise FileExistsError(f"Snapshot already exists: {snapshot_dir}")
shutil.rmtree(snapshot_dir)
snapshot_dir.mkdir(parents=True, exist_ok=True)
rows_by_table: dict[str, list[dict[str, Any]]] = {}
for table_name in SMALL_TABLES:
rows = read_parquet_rows(checkpoint_files[f"{table_name}.parquet"])
rows_by_table[table_name] = _project_rows(rows, table_name)
write_parquet(rows_by_table[table_name], snapshot_dir / f"{table_name}.parquet", table_name)
for table_name in COPIED_TABLES:
shutil.copy2(
checkpoint_files[f"{table_name}.parquet"], snapshot_dir / f"{table_name}.parquet"
)
repo_slug = _resolve_repo(rows_by_table, source_manifest, source_progress)
extracted_at = _resolve_extracted_at(rows_by_table, source_manifest)
link_rows = _derived_link_rows(
repo_slug=repo_slug,
snapshot_id=checkpoint_id,
extracted_at=extracted_at,
issues=rows_by_table["issues"],
pull_requests=rows_by_table["pull_requests"],
comments=rows_by_table["comments"],
reviews=rows_by_table["reviews"],
review_comments=rows_by_table["review_comments"],
events=rows_by_table["events"],
)
write_parquet(link_rows, snapshot_dir / "links.parquet", "links")
issue_comment_rows, pr_comment_rows = _viewer_comment_rows(
rows_by_table["comments"], rows_by_table["pull_requests"]
)
write_parquet(issue_comment_rows, snapshot_dir / "issue_comments.parquet", "comments")
write_parquet(pr_comment_rows, snapshot_dir / "pr_comments.parquet", "comments")
manifest = {
"repo": repo_slug,
"snapshot_id": checkpoint_id,
"extracted_at": extracted_at,
"imported_at": _iso_now(),
"source_type": "hf_checkpoint_import",
"source_hf_repo_id": source_repo_id,
"source_checkpoint_root": checkpoint_root,
"source_checkpoint_id": checkpoint_id,
"source_manifest": source_manifest,
"source_progress": source_progress,
"counts": {
"issues": len(rows_by_table["issues"]),
"pull_requests": len(rows_by_table["pull_requests"]),
"comments": len(rows_by_table["comments"]),
"reviews": len(rows_by_table["reviews"]),
"review_comments": len(rows_by_table["review_comments"]),
"pr_files": _parquet_row_count(snapshot_dir / "pr_files.parquet"),
"pr_diffs": _parquet_row_count(snapshot_dir / "pr_diffs.parquet"),
"timeline_events": len(rows_by_table["events"]),
"links": len(link_rows),
},
"notes": (
"Imported from a remote HF dataset checkpoint and regenerated locally derived link/comment views."
),
}
write_json(manifest, snapshot_dir / "manifest.json")
write_text(
_dataset_card(repo_slug, checkpoint_id, source_repo_id, checkpoint_root),
snapshot_dir / "README.md",
)
return snapshot_dir


def _select_checkpoint(
sibling_paths: list[str],
*,
checkpoint_id: str | None,
checkpoint_root: str | None,
) -> tuple[str, str]:
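    """Choose a ``(root, snapshot_id)`` directory that has every required file."""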
candidates: defaultdict[tuple[str, str], set[str]] = defaultdict(set)
for path in sibling_paths:
match = CHECKPOINT_PATH_PATTERN.match(path)
if match is None:
continue
root = match.group("root")
snapshot_id = match.group("snapshot_id")
filename = match.group("filename")
candidates[(root, snapshot_id)].add(filename)
viable = [
(root, snapshot_id)
for (root, snapshot_id), filenames in candidates.items()
if all(filename in filenames for filename in REQUIRED_CHECKPOINT_FILES)
]
if checkpoint_id is not None:
requested = [
candidate
for candidate in viable
if candidate[1] == checkpoint_id
and (checkpoint_root is None or candidate[0] == checkpoint_root)
]
if not requested:
raise ValueError(
f"Checkpoint {checkpoint_root or '*'}:{checkpoint_id} not found with required files."
)
return sorted(requested, key=_checkpoint_sort_key)[-1]
if checkpoint_root is not None:
viable = [candidate for candidate in viable if candidate[0] == checkpoint_root]
if not viable:
raise ValueError("No viable checkpoint directories were found in the source HF dataset.")
return sorted(viable, key=_checkpoint_sort_key)[-1]


def _checkpoint_sort_key(candidate: tuple[str, str]) -> tuple[str, int]:
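    # Callers take the last element of the sorted candidates, so the
    # lexicographically greatest snapshot id wins (the newest, for
    # timestamp-style ids) and "_checkpoints" beats "checkpoints" on ties.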
root, snapshot_id = candidate
root_priority = 1 if root == "_checkpoints" else 0
return snapshot_id, root_priority


def _maybe_download_json(repo_id: str, filename: str) -> dict[str, Any]:
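    """Fetch an optional JSON sidecar file, returning ``{}`` when unavailable."""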
try:
path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=filename)
except Exception:
return {}
return read_json(Path(path))


def _project_rows(rows: list[dict[str, Any]], table_name: str) -> list[dict[str, Any]]:
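    # Project each row onto the current schema: unknown columns are dropped
    # and columns missing from the source come back as None.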
field_names = [field.name for field in SCHEMAS[table_name]]
return [{field_name: row.get(field_name) for field_name in field_names} for row in rows]


def _resolve_repo(
rows_by_table: Mapping[str, list[dict[str, Any]]],
source_manifest: Mapping[str, Any],
source_progress: Mapping[str, Any],
) -> str:
for table_name in ("issues", "pull_requests", "comments"):
rows = rows_by_table.get(table_name) or []
if rows and rows[0].get("repo"):
return str(rows[0]["repo"])
return str(source_manifest.get("repo") or source_progress.get("repo") or "")


def _resolve_extracted_at(
rows_by_table: Mapping[str, list[dict[str, Any]]],
source_manifest: Mapping[str, Any],
) -> str:
for table_name in (
"issues",
"pull_requests",
"comments",
"reviews",
"review_comments",
"events",
):
rows = rows_by_table.get(table_name) or []
if rows and rows[0].get("extracted_at"):
return str(rows[0]["extracted_at"])
return str(source_manifest.get("extracted_at") or _iso_now())


def _derived_link_rows(
*,
repo_slug: str,
snapshot_id: str,
extracted_at: str,
issues: list[dict[str, Any]],
pull_requests: list[dict[str, Any]],
comments: list[dict[str, Any]],
reviews: list[dict[str, Any]],
review_comments: list[dict[str, Any]],
events: list[dict[str, Any]],
) -> list[dict[str, Any]]:
owner, repo_name = repo_slug.split("/", 1)
link_rows: list[dict[str, Any]] = []
for issue_row in issues:
link_rows.extend(
build_text_link_rows(
repo=repo_slug,
owner=owner,
repo_name=repo_name,
source_type="issue",
source_number=int(issue_row["number"]),
source_id=issue_row.get("github_id"),
body=issue_row.get("body"),
snapshot_id=snapshot_id,
extracted_at=extracted_at,
)
)
for pr_row in pull_requests:
link_rows.extend(
build_text_link_rows(
repo=repo_slug,
owner=owner,
repo_name=repo_name,
source_type="pull_request",
source_number=int(pr_row["number"]),
source_id=pr_row.get("github_id"),
body=pr_row.get("body"),
snapshot_id=snapshot_id,
extracted_at=extracted_at,
)
)
for comment_row in comments:
parent_number = comment_row.get("parent_number")
if parent_number is None:
continue
link_rows.extend(
build_text_link_rows(
repo=repo_slug,
owner=owner,
repo_name=repo_name,
source_type="comment",
source_number=int(parent_number),
source_id=comment_row.get("github_id"),
body=comment_row.get("body"),
snapshot_id=snapshot_id,
extracted_at=extracted_at,
)
)
for review_row in reviews:
link_rows.extend(
build_text_link_rows(
repo=repo_slug,
owner=owner,
repo_name=repo_name,
source_type="review",
source_number=int(review_row["pull_request_number"]),
source_id=review_row.get("github_id"),
body=review_row.get("body"),
snapshot_id=snapshot_id,
extracted_at=extracted_at,
)
)
for review_comment_row in review_comments:
link_rows.extend(
build_text_link_rows(
repo=repo_slug,
owner=owner,
repo_name=repo_name,
source_type="review_comment",
source_number=int(review_comment_row["pull_request_number"]),
source_id=review_comment_row.get("github_id"),
body=review_comment_row.get("body"),
snapshot_id=snapshot_id,
extracted_at=extracted_at,
)
)
link_rows.extend(
build_pr_duplicate_candidate_rows(
repo=repo_slug,
pull_requests=pull_requests,
link_rows=link_rows,
snapshot_id=snapshot_id,
extracted_at=extracted_at,
)
)
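    # Cross-referenced timeline events become synthetic "timeline:<event>"
    # link rows pointing at the referenced issue or PR.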
for event in events:
source_issue_number = event.get("source_issue_number")
if not source_issue_number:
continue
link_rows.append(
{
"repo": repo_slug,
"source_type": event["parent_kind"],
"source_number": event["parent_number"],
"source_github_id": None,
"target_owner": owner,
"target_repo": repo_name,
"target_number": source_issue_number,
"link_type": f"timeline:{event['event']}",
"link_origin": "timeline",
"snapshot_id": snapshot_id,
"extracted_at": extracted_at,
}
)
return sorted(
_dedupe_link_rows(link_rows),
key=lambda row: (
str(row.get("source_type") or ""),
int(row.get("source_number") or 0),
int(row.get("source_github_id") or 0),
str(row.get("target_owner") or ""),
str(row.get("target_repo") or ""),
int(row.get("target_number") or 0),
str(row.get("link_type") or ""),
str(row.get("link_origin") or ""),
),
)


def _dedupe_link_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
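    # The dict is keyed on the full link identity, so a later duplicate
    # replaces an earlier row and the last occurrence wins.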
deduped: dict[tuple[Any, ...], dict[str, Any]] = {}
for row in rows:
key = (
row.get("repo"),
row.get("source_type"),
row.get("source_number"),
row.get("source_github_id"),
row.get("target_owner"),
row.get("target_repo"),
row.get("target_number"),
row.get("link_type"),
row.get("link_origin"),
)
deduped[key] = row
return list(deduped.values())


def _viewer_comment_rows(
comments: list[dict[str, Any]],
pull_requests: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
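    # Split the unified comments table into the issue/PR views the dataset
    # viewer expects: a comment is routed to pr_comments when its parent_kind
    # says so or its parent number matches a known pull request.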
pr_numbers = {int(row["number"]) for row in pull_requests if row.get("number") is not None}
issue_comments: list[dict[str, Any]] = []
pr_comments: list[dict[str, Any]] = []
for row in comments:
parent_number = row.get("parent_number")
parent_kind = row.get("parent_kind")
if parent_kind == "pull_request" or parent_number in pr_numbers:
pr_comments.append(row)
else:
issue_comments.append(row)
return issue_comments, pr_comments


def _dataset_card(
repo_slug: str, snapshot_id: str, source_repo_id: str, checkpoint_root: str
) -> str:
return build_hf_dataset_card(
repo_slug,
snapshot_id,
notes=[
f"source HF dataset: `{source_repo_id}`",
f"source checkpoint root: `{checkpoint_root}`",
"links were regenerated locally from text references and timeline events",
],
)


def _snapshot_dir_name(source_repo_id: str, checkpoint_id: str) -> str:
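    # e.g. ("owner/name", "2024-01-01") -> "hf-owner--name-2024-01-01"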
return f"hf-{source_repo_id.replace('/', '--')}-{checkpoint_id}"


def _iso_now() -> str:
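    # Second-resolution UTC timestamp such as "2024-01-01T00:00:00Z".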
return datetime.now(tz=UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")


def _parquet_row_count(path: Path) -> int:
import pyarrow.parquet as pq
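    # read_metadata inspects only the parquet footer, so the row count is
    # available without loading any row data.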
return pq.read_metadata(path).num_rows