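"""Materialize Hugging Face dataset repos into local snapshot directories.

``materialize_hf_dataset_snapshot`` picks one of three strategies based on the
files present in the repo: a snapshot repo exposing the ``SNAPSHOTS_LATEST_PATH``
pointer, a flat layout with ``issues.parquet``/``pull_requests.parquet`` at the
root, or a fallback that downloads the dataset viewer's auto-converted parquet
export and classifies each file by its schema.

Minimal usage sketch (the repo id below is illustrative only):

    snapshot_dir = materialize_hf_dataset_snapshot(
        repo_id="some-org/some-dataset",
        local_dir=Path("local-snapshots/some-dataset"),
    )
"""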
from __future__ import annotations
import json
import shutil
import urllib.parse
from datetime import UTC, datetime
from pathlib import Path, PurePosixPath
from typing import Any
from huggingface_hub import HfApi, hf_hub_download
from slop_farmer.data.http import urlopen_with_retry
from slop_farmer.data.parquet_io import read_json, write_text
from slop_farmer.data.snapshot_paths import (
CONTRIBUTOR_ARTIFACT_FILENAMES,
CURRENT_ANALYSIS_MANIFEST_PATH,
LEGACY_ANALYSIS_FILENAMES,
PR_SCOPE_CLUSTERS_FILENAME,
RAW_TABLE_FILENAMES,
README_FILENAME,
ROOT_MANIFEST_FILENAME,
SNAPSHOTS_LATEST_PATH,
STATE_WATERMARK_PATH,
load_archived_analysis_run_manifest,
load_current_analysis_manifest,
repo_relative_path_to_local,
)
def materialize_hf_dataset_snapshot(
*,
repo_id: str,
local_dir: Path,
revision: str | None = None,
) -> Path:
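    """Download ``repo_id`` into ``local_dir``, choosing a strategy from the repo layout.

    Prefers the snapshot-repo layout (latest pointer file present), then a flat
    root layout with the raw parquet tables, and finally falls back to the
    dataset viewer parquet export.
    """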
info = _hf_dataset_info(repo_id=repo_id, revision=revision, files_metadata=True)
remote_paths = {sibling.rfilename for sibling in info.siblings}
resolved_revision = str(info.sha or revision or "main")
if SNAPSHOTS_LATEST_PATH in remote_paths:
return _materialize_hf_snapshot_repo_snapshot(
repo_id=repo_id,
local_dir=local_dir,
revision=resolved_revision,
requested_revision=revision,
hf_sha=info.sha,
remote_paths=remote_paths,
)
if {"issues.parquet", "pull_requests.parquet"} <= remote_paths:
return _materialize_hf_root_snapshot(
repo_id=repo_id,
local_dir=local_dir,
revision=resolved_revision,
requested_revision=revision,
hf_sha=info.sha,
remote_paths=remote_paths,
)
return _materialize_hf_dataset_viewer_snapshot(
repo_id=repo_id,
local_dir=local_dir,
revision=resolved_revision,
requested_revision=revision,
hf_sha=info.sha,
)
def _materialize_hf_snapshot_repo_snapshot(
*,
repo_id: str,
local_dir: Path,
revision: str,
requested_revision: str | None,
hf_sha: str | None,
remote_paths: set[str],
) -> Path:
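    """Materialize a repo that publishes snapshots behind a latest-pointer file.

    Resolves each well-known filename through the pointer payload, mirrors the
    optional watermark, analysis-state, and published analysis files, then
    rewrites the root manifest with provenance for this materialization.
    """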
local_dir.mkdir(parents=True, exist_ok=True)
latest_download = Path(
hf_hub_download(
repo_id=repo_id,
repo_type="dataset",
filename=SNAPSHOTS_LATEST_PATH,
revision=revision,
)
)
latest_payload = json.loads(latest_download.read_text(encoding="utf-8"))
downloaded_files: set[str] = set()
_copy_downloaded_file(
latest_download, repo_relative_path_to_local(local_dir, SNAPSHOTS_LATEST_PATH)
)
downloaded_files.add(SNAPSHOTS_LATEST_PATH)
for filename in (
*RAW_TABLE_FILENAMES,
ROOT_MANIFEST_FILENAME,
PR_SCOPE_CLUSTERS_FILENAME,
*CONTRIBUTOR_ARTIFACT_FILENAMES,
*LEGACY_ANALYSIS_FILENAMES,
):
downloaded = _download_first_available_hf_file(
repo_id=repo_id,
revision=revision,
filenames=_hf_latest_snapshot_candidates(latest_payload, filename),
)
if downloaded is None:
continue
_copy_downloaded_file(downloaded, local_dir / filename)
downloaded_files.add(filename)
if STATE_WATERMARK_PATH in remote_paths:
_download_repo_file(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
repo_path=STATE_WATERMARK_PATH,
downloaded_files=downloaded_files,
)
_download_analysis_state_files(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
remote_paths=remote_paths,
downloaded_files=downloaded_files,
)
_download_published_analysis_files(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
remote_paths=remote_paths,
downloaded_files=downloaded_files,
)
_download_repo_file(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
repo_path=README_FILENAME,
downloaded_files=downloaded_files,
required=False,
)
manifest = (
read_json(local_dir / ROOT_MANIFEST_FILENAME)
if (local_dir / ROOT_MANIFEST_FILENAME).exists()
else {}
)
manifest.setdefault("repo", _infer_repo_from_materialized_snapshot(local_dir))
manifest.setdefault(
"snapshot_id",
str(latest_payload.get("latest_snapshot_id") or hf_sha or local_dir.name),
)
manifest.update(
{
"source_type": "hf_snapshot_repo",
"hf_repo_id": repo_id,
"hf_revision": requested_revision,
"hf_resolved_revision": revision,
"hf_sha": hf_sha,
"materialized_at": _iso_now(),
"downloaded_files": sorted(downloaded_files),
"hf_latest_pointer": latest_payload,
}
)
write_text(json.dumps(manifest, indent=2) + "\n", local_dir / ROOT_MANIFEST_FILENAME)
return local_dir
def _materialize_hf_root_snapshot(
*,
repo_id: str,
local_dir: Path,
revision: str,
requested_revision: str | None,
hf_sha: str | None,
remote_paths: set[str],
) -> Path:
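    """Materialize a repo that stores the raw tables and manifests at its root.

    Downloads each known table, manifest, and artifact file that exists in the
    repo, plus analysis-state and published analysis artifacts, and records
    provenance in the root manifest.
    """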
local_dir.mkdir(parents=True, exist_ok=True)
downloaded_files: set[str] = set()
for repo_path in (
*RAW_TABLE_FILENAMES,
ROOT_MANIFEST_FILENAME,
PR_SCOPE_CLUSTERS_FILENAME,
*CONTRIBUTOR_ARTIFACT_FILENAMES,
*LEGACY_ANALYSIS_FILENAMES,
SNAPSHOTS_LATEST_PATH,
STATE_WATERMARK_PATH,
README_FILENAME,
):
if repo_path not in remote_paths:
continue
_download_repo_file(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
repo_path=repo_path,
downloaded_files=downloaded_files,
)
_download_analysis_state_files(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
remote_paths=remote_paths,
downloaded_files=downloaded_files,
)
_download_published_analysis_files(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
remote_paths=remote_paths,
downloaded_files=downloaded_files,
)
manifest = (
read_json(local_dir / ROOT_MANIFEST_FILENAME)
if (local_dir / ROOT_MANIFEST_FILENAME).exists()
else {}
)
manifest.setdefault("repo", _infer_repo_from_materialized_snapshot(local_dir))
manifest.setdefault("snapshot_id", hf_sha or local_dir.name)
manifest.update(
{
"source_type": "hf_root_snapshot",
"hf_repo_id": repo_id,
"hf_revision": requested_revision,
"hf_resolved_revision": revision,
"hf_sha": hf_sha,
"materialized_at": _iso_now(),
"downloaded_files": sorted(downloaded_files),
}
)
write_text(json.dumps(manifest, indent=2) + "\n", local_dir / ROOT_MANIFEST_FILENAME)
return local_dir
def _materialize_hf_dataset_viewer_snapshot(
*,
repo_id: str,
local_dir: Path,
revision: str,
requested_revision: str | None,
hf_sha: str | None,
) -> Path:
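    """Fall back to the dataset viewer's auto-converted parquet export.

    Each exported parquet file is downloaded under a temporary name, classified
    by its column schema, and renamed to the canonical table filename; the repo
    README and a freshly built root manifest are written alongside it.
    """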
local_dir.mkdir(parents=True, exist_ok=True)
downloaded_files: set[str] = set()
for index, url in enumerate(_hf_dataset_parquet_urls(repo_id, revision)):
temporary_path = local_dir / f"tmp-{index:04d}.parquet"
_download_url_to_path(url, temporary_path)
table_name = _parquet_table_name(temporary_path)
temporary_path.replace(local_dir / table_name)
downloaded_files.add(table_name)
readme_path = hf_hub_download(
repo_id=repo_id,
repo_type="dataset",
filename=README_FILENAME,
revision=revision,
)
shutil.copy2(readme_path, local_dir / README_FILENAME)
downloaded_files.add(README_FILENAME)
manifest = {
"repo": _infer_repo_from_materialized_snapshot(local_dir),
"snapshot_id": hf_sha or local_dir.name,
"source_type": "hf_dataset_viewer",
"hf_repo_id": repo_id,
"hf_revision": requested_revision,
"hf_resolved_revision": revision,
"hf_sha": hf_sha,
"materialized_at": _iso_now(),
"downloaded_files": sorted(downloaded_files),
}
write_text(json.dumps(manifest, indent=2) + "\n", local_dir / ROOT_MANIFEST_FILENAME)
return local_dir
def _download_published_analysis_files(
*,
repo_id: str,
revision: str,
local_dir: Path,
remote_paths: set[str],
downloaded_files: set[str],
) -> None:
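    """Fetch analysis artifacts referenced by the current and archived manifests.

    Downloads the current-analysis manifest (when present) and every artifact
    path it lists, then does the same for each archived per-run manifest found
    under the ``snapshots/`` tree.
    """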
if CURRENT_ANALYSIS_MANIFEST_PATH in remote_paths:
manifest_path = _download_repo_file(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
repo_path=CURRENT_ANALYSIS_MANIFEST_PATH,
downloaded_files=downloaded_files,
)
current_manifest = load_current_analysis_manifest(manifest_path)
for repo_path in _manifest_artifact_paths(current_manifest, include_archived=True):
if repo_path not in remote_paths:
continue
_download_repo_file(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
repo_path=repo_path,
downloaded_files=downloaded_files,
)
for repo_path in sorted(
path for path in remote_paths if _is_archived_analysis_manifest_path(path)
):
manifest_path = _download_repo_file(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
repo_path=repo_path,
downloaded_files=downloaded_files,
)
archived_manifest = load_archived_analysis_run_manifest(manifest_path)
for artifact_path in _manifest_artifact_paths(archived_manifest, include_archived=False):
if artifact_path not in remote_paths:
continue
_download_repo_file(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
repo_path=artifact_path,
downloaded_files=downloaded_files,
)
def _download_analysis_state_files(
*,
repo_id: str,
revision: str,
local_dir: Path,
remote_paths: set[str],
downloaded_files: set[str],
) -> None:
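    """Mirror every file stored under the repo's ``analysis-state/`` directory."""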
for repo_path in sorted(
path for path in remote_paths if PurePosixPath(path).parts[:1] == ("analysis-state",)
):
_download_repo_file(
repo_id=repo_id,
revision=revision,
local_dir=local_dir,
repo_path=repo_path,
downloaded_files=downloaded_files,
)
def _manifest_artifact_paths(
payload: dict[str, Any],
*,
include_archived: bool,
) -> list[str]:
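    """Collect artifact repo paths from a manifest payload.

    Reads ``artifacts`` and, when ``include_archived`` is set, ``archived_artifacts``;
    paths are stripped of leading ``.`` and ``/`` characters and de-duplicated in order.
    """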
paths = [
str(value) for value in (payload.get("artifacts") or {}).values() if isinstance(value, str)
]
if include_archived:
paths.extend(
str(value)
for value in (payload.get("archived_artifacts") or {}).values()
if isinstance(value, str)
)
deduped: list[str] = []
seen: set[str] = set()
for repo_path in paths:
normalized = repo_path.lstrip("./")
if not normalized or normalized in seen:
continue
seen.add(normalized)
deduped.append(normalized)
return deduped
def _is_archived_analysis_manifest_path(repo_path: str) -> bool:
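    """Return True for ``snapshots/<snapshot>/analysis-runs/<run>/`` root-manifest paths."""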
parts = PurePosixPath(repo_path).parts
return (
len(parts) == 5
and parts[0] == "snapshots"
and parts[2] == "analysis-runs"
and parts[4] == ROOT_MANIFEST_FILENAME
)
def _download_repo_file(
*,
repo_id: str,
revision: str,
local_dir: Path,
repo_path: str,
downloaded_files: set[str],
required: bool = True,
) -> Path:
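    """Download ``repo_path`` from the dataset repo into its local mirror location.

    Failures raise unless ``required`` is False, in which case a path under
    ``local_dir`` is returned without being recorded in ``downloaded_files``.
    """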
try:
downloaded = Path(
hf_hub_download(
repo_id=repo_id,
repo_type="dataset",
filename=repo_path,
revision=revision,
)
)
except Exception:
if required:
raise
return local_dir / repo_path
destination = repo_relative_path_to_local(local_dir, repo_path)
_copy_downloaded_file(downloaded, destination)
downloaded_files.add(repo_path)
return destination
def _copy_downloaded_file(downloaded_path: Path, destination: Path) -> None:
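    """Copy a cached hub download into place, creating the destination's parent directories."""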
destination.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(downloaded_path, destination)
def _hf_dataset_info(repo_id: str, revision: str | None, *, files_metadata: bool) -> Any:
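    """Fetch dataset metadata, tolerating hub clients without ``files_metadata`` support."""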
api = HfApi()
try:
return api.dataset_info(repo_id=repo_id, revision=revision, files_metadata=files_metadata)
except TypeError:
return api.dataset_info(repo_id=repo_id, revision=revision)
def _hf_dataset_parquet_urls(repo_id: str, revision: str | None = None) -> list[str]:
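    """Return the dataset viewer's parquet export URLs for the default config's train split."""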
query = urllib.parse.urlencode({"revision": revision}) if revision else ""
api_url = (
f"https://huggingface.co/api/datasets/{urllib.parse.quote(repo_id, safe='')}/parquet"
f"{f'?{query}' if query else ''}"
)
with urlopen_with_retry(api_url, timeout=120, label=api_url) as response:
payload = json.loads(response.read().decode("utf-8"))
urls = payload.get("default", {}).get("train", [])
if not isinstance(urls, list) or not urls:
raise FileNotFoundError(
f"No parquet export URLs found for HF dataset {repo_id} at {api_url}"
)
return [str(url) for url in urls]
def _download_first_available_hf_file(
*,
repo_id: str,
revision: str,
filenames: list[str],
) -> Path | None:
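    """Try each candidate filename in order and return the first successful download."""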
for filename in filenames:
try:
downloaded = Path(
hf_hub_download(
repo_id=repo_id,
repo_type="dataset",
filename=filename,
revision=revision,
)
)
except Exception:
continue
if downloaded.exists():
return downloaded
return None
def _hf_latest_snapshot_candidates(latest_payload: dict[str, Any], filename: str) -> list[str]:
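    """Build an ordered, de-duplicated list of repo paths that may hold ``filename``.

    Candidates come from the latest-pointer payload: the (archived) manifest
    paths when resolving the root manifest itself, the snapshot directory, the
    manifest's parent directory, ``snapshots/<latest_snapshot_id>/``, and
    finally the bare filename at the repo root.
    """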
candidates: list[str] = []
manifest_path = str(latest_payload.get("manifest_path") or "").strip("/")
snapshot_dir = str(latest_payload.get("snapshot_dir") or "").strip("/")
latest_snapshot_id = str(latest_payload.get("latest_snapshot_id") or "").strip()
archived_manifest_path = str(latest_payload.get("archived_manifest_path") or "").strip("/")
if filename == ROOT_MANIFEST_FILENAME and manifest_path:
candidates.append(manifest_path)
if snapshot_dir and snapshot_dir not in {".", "/"}:
candidates.append(f"{snapshot_dir}/{filename}")
if filename == ROOT_MANIFEST_FILENAME and archived_manifest_path:
candidates.append(archived_manifest_path)
if manifest_path and "/" in manifest_path:
manifest_dir = manifest_path.rsplit("/", 1)[0]
candidates.append(f"{manifest_dir}/{filename}")
if latest_snapshot_id:
candidates.append(str(PurePosixPath("snapshots") / latest_snapshot_id / filename))
candidates.append(filename)
seen: set[str] = set()
deduped: list[str] = []
for candidate in candidates:
normalized = candidate.lstrip("./")
if not normalized or normalized in seen:
continue
seen.add(normalized)
deduped.append(normalized)
return deduped
def _download_url_to_path(url: str, destination: Path) -> None:
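    """Download ``url`` to ``destination``, creating parent directories first."""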
    destination.parent.mkdir(parents=True, exist_ok=True)
    # Stream through the shared retrying opener rather than urllib's legacy urlretrieve.
    with urlopen_with_retry(url, timeout=120, label=url) as response:
        with destination.open("wb") as out:
            shutil.copyfileobj(response, out)
def _parquet_table_name(path: Path) -> str:
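    """Infer the canonical table filename from a parquet file's column schema."""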
import pyarrow.parquet as pq
    # Only the column names are needed for classification, so read just the schema.
    columns = set(pq.read_schema(path).names)
if {"parent_kind", "issue_api_url", "body"} <= columns:
return "comments.parquet"
if {"event", "source_issue_number", "source_issue_url"} <= columns:
return "events.parquet"
if {"milestone_title", "comments_count"} <= columns and "merged_at" not in columns:
return "issues.parquet"
if {"link_type", "link_origin", "target_number"} <= columns:
return "links.parquet"
if {"pull_request_number", "filename", "blob_url", "patch"} <= columns:
return "pr_files.parquet"
if {"pull_request_number", "diff", "html_url", "api_url"} <= columns:
return "pr_diffs.parquet"
if {"merged_at", "head_ref", "base_ref"} <= columns:
return "pull_requests.parquet"
if {"review_id", "pull_request_api_url", "path"} <= columns:
return "review_comments.parquet"
if {"pull_request_number", "submitted_at"} <= columns and "review_id" not in columns:
return "reviews.parquet"
raise ValueError(f"Unrecognized HF parquet schema for {path.name}: {sorted(columns)}")
def _infer_repo_from_materialized_snapshot(local_dir: Path) -> str:
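    """Read the ``repo`` column from the first row of the first available raw table."""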
import pyarrow.parquet as pq
for table_filename in RAW_TABLE_FILENAMES:
path = local_dir / table_filename
if not path.exists():
continue
rows = pq.read_table(path).slice(0, 1).to_pylist()
if rows and rows[0].get("repo"):
return str(rows[0]["repo"])
raise FileNotFoundError(f"Could not infer repo from materialized snapshot in {local_dir}")
def _iso_now() -> str:
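    """Current UTC time as a second-precision ISO-8601 string with a ``Z`` suffix."""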
return datetime.now(tz=UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")