# diffusers-pr-api / src/slop_farmer/data/parquet_io.py
# Deployed via "Deploy Diffusers PR API" (commit dbf7313, verified) by evalstate (HF Staff).
from __future__ import annotations
import json
import os
import tempfile
from pathlib import Path
from typing import Any
import pyarrow as pa
import pyarrow.parquet as pq
# Arrow schemas for every table this module persists, keyed by table name
# (used by write_parquet to validate rows).  Timestamps are stored as plain
# strings rather than Arrow timestamp types — presumably ISO-8601 values
# passed through from the GitHub API verbatim (TODO confirm against the
# extractors).  Every table carries `snapshot_id` / `extracted_at`
# provenance columns; most also carry the flattened `author_*` columns.
SCHEMAS: dict[str, pa.Schema] = {
    # One row per GitHub issue.
    "issues": pa.schema(
        [
            ("repo", pa.string()),
            ("github_id", pa.int64()),
            ("github_node_id", pa.string()),
            ("number", pa.int64()),
            ("html_url", pa.string()),
            ("api_url", pa.string()),
            ("title", pa.string()),
            ("body", pa.string()),
            ("state", pa.string()),
            ("state_reason", pa.string()),
            ("locked", pa.bool_()),
            ("comments_count", pa.int64()),
            ("labels", pa.list_(pa.string())),
            ("assignees", pa.list_(pa.string())),
            ("created_at", pa.string()),
            ("updated_at", pa.string()),
            ("closed_at", pa.string()),
            ("author_association", pa.string()),
            ("milestone_title", pa.string()),
            ("snapshot_id", pa.string()),
            ("extracted_at", pa.string()),
            # Flattened author sub-object.
            ("author_login", pa.string()),
            ("author_id", pa.int64()),
            ("author_node_id", pa.string()),
            ("author_type", pa.string()),
            ("author_site_admin", pa.bool_()),
        ]
    ),
    # One row per pull request; superset of the issue columns plus
    # merge/diff stats and head/base branch metadata.
    "pull_requests": pa.schema(
        [
            ("repo", pa.string()),
            ("github_id", pa.int64()),
            ("github_node_id", pa.string()),
            ("number", pa.int64()),
            ("html_url", pa.string()),
            ("api_url", pa.string()),
            ("title", pa.string()),
            ("body", pa.string()),
            ("state", pa.string()),
            ("state_reason", pa.string()),
            ("locked", pa.bool_()),
            ("comments_count", pa.int64()),
            ("labels", pa.list_(pa.string())),
            ("assignees", pa.list_(pa.string())),
            ("created_at", pa.string()),
            ("updated_at", pa.string()),
            ("closed_at", pa.string()),
            ("author_association", pa.string()),
            # Merge state.
            ("merged_at", pa.string()),
            ("merge_commit_sha", pa.string()),
            ("merged", pa.bool_()),
            ("mergeable", pa.bool_()),
            ("mergeable_state", pa.string()),
            ("draft", pa.bool_()),
            # Diff statistics.
            ("additions", pa.int64()),
            ("deletions", pa.int64()),
            ("changed_files", pa.int64()),
            ("commits", pa.int64()),
            ("review_comments_count", pa.int64()),
            ("maintainer_can_modify", pa.bool_()),
            # Head/base branch refs.
            ("head_ref", pa.string()),
            ("head_sha", pa.string()),
            ("head_repo_full_name", pa.string()),
            ("base_ref", pa.string()),
            ("base_sha", pa.string()),
            ("base_repo_full_name", pa.string()),
            ("snapshot_id", pa.string()),
            ("extracted_at", pa.string()),
            ("author_login", pa.string()),
            ("author_id", pa.int64()),
            ("author_node_id", pa.string()),
            ("author_type", pa.string()),
            ("author_site_admin", pa.bool_()),
        ]
    ),
    # Issue/PR conversation comments; parent_kind/parent_number identify
    # the owning issue or pull request.
    "comments": pa.schema(
        [
            ("repo", pa.string()),
            ("github_id", pa.int64()),
            ("github_node_id", pa.string()),
            ("parent_kind", pa.string()),
            ("parent_number", pa.int64()),
            ("html_url", pa.string()),
            ("api_url", pa.string()),
            ("issue_api_url", pa.string()),
            ("body", pa.string()),
            ("created_at", pa.string()),
            ("updated_at", pa.string()),
            ("author_association", pa.string()),
            ("snapshot_id", pa.string()),
            ("extracted_at", pa.string()),
            ("author_login", pa.string()),
            ("author_id", pa.int64()),
            ("author_node_id", pa.string()),
            ("author_type", pa.string()),
            ("author_site_admin", pa.bool_()),
        ]
    ),
    # Top-level PR reviews (approve / request-changes / comment).
    "reviews": pa.schema(
        [
            ("repo", pa.string()),
            ("github_id", pa.int64()),
            ("github_node_id", pa.string()),
            ("pull_request_number", pa.int64()),
            ("html_url", pa.string()),
            ("api_url", pa.string()),
            ("body", pa.string()),
            ("state", pa.string()),
            ("submitted_at", pa.string()),
            ("commit_id", pa.string()),
            ("author_association", pa.string()),
            ("snapshot_id", pa.string()),
            ("extracted_at", pa.string()),
            ("author_login", pa.string()),
            ("author_id", pa.int64()),
            ("author_node_id", pa.string()),
            ("author_type", pa.string()),
            ("author_site_admin", pa.bool_()),
        ]
    ),
    # Inline review comments anchored to a file/line in the PR diff.
    "review_comments": pa.schema(
        [
            ("repo", pa.string()),
            ("github_id", pa.int64()),
            ("github_node_id", pa.string()),
            ("pull_request_number", pa.int64()),
            ("review_id", pa.int64()),
            ("html_url", pa.string()),
            ("api_url", pa.string()),
            ("pull_request_api_url", pa.string()),
            ("body", pa.string()),
            # Diff anchor: file path plus current/original commit + position.
            ("path", pa.string()),
            ("commit_id", pa.string()),
            ("original_commit_id", pa.string()),
            ("position", pa.int64()),
            ("original_position", pa.int64()),
            ("line", pa.int64()),
            ("start_line", pa.int64()),
            ("side", pa.string()),
            ("start_side", pa.string()),
            ("subject_type", pa.string()),
            ("created_at", pa.string()),
            ("updated_at", pa.string()),
            ("author_association", pa.string()),
            ("snapshot_id", pa.string()),
            ("extracted_at", pa.string()),
            ("author_login", pa.string()),
            ("author_id", pa.int64()),
            ("author_node_id", pa.string()),
            ("author_type", pa.string()),
            ("author_site_admin", pa.bool_()),
        ]
    ),
    # Per-file change summary of a PR, including the unified patch text.
    "pr_files": pa.schema(
        [
            ("repo", pa.string()),
            ("pull_request_number", pa.int64()),
            ("sha", pa.string()),
            ("filename", pa.string()),
            ("status", pa.string()),
            ("additions", pa.int64()),
            ("deletions", pa.int64()),
            ("changes", pa.int64()),
            ("blob_url", pa.string()),
            ("raw_url", pa.string()),
            ("contents_url", pa.string()),
            ("previous_filename", pa.string()),
            ("patch", pa.string()),
            ("snapshot_id", pa.string()),
            ("extracted_at", pa.string()),
        ]
    ),
    # Whole-PR diff stored as a single string column.
    "pr_diffs": pa.schema(
        [
            ("repo", pa.string()),
            ("pull_request_number", pa.int64()),
            ("html_url", pa.string()),
            ("api_url", pa.string()),
            ("diff", pa.string()),
            ("snapshot_id", pa.string()),
            ("extracted_at", pa.string()),
        ]
    ),
    # Cross-references between issues/PRs (possibly across repositories).
    "links": pa.schema(
        [
            ("repo", pa.string()),
            ("source_type", pa.string()),
            ("source_number", pa.int64()),
            ("source_github_id", pa.int64()),
            ("target_owner", pa.string()),
            ("target_repo", pa.string()),
            ("target_number", pa.int64()),
            ("link_type", pa.string()),
            ("link_origin", pa.string()),
            ("snapshot_id", pa.string()),
            ("extracted_at", pa.string()),
        ]
    ),
    # Timeline events on an issue or PR (labels, cross-references, etc.).
    "events": pa.schema(
        [
            ("repo", pa.string()),
            ("parent_kind", pa.string()),
            ("parent_number", pa.int64()),
            ("event", pa.string()),
            ("created_at", pa.string()),
            ("actor_login", pa.string()),
            ("source_issue_number", pa.int64()),
            ("source_issue_title", pa.string()),
            ("source_issue_url", pa.string()),
            ("commit_id", pa.string()),
            ("label_name", pa.string()),
            ("snapshot_id", pa.string()),
            ("extracted_at", pa.string()),
        ]
    ),
    # Derived report rows about first-time contributors, including
    # heuristic risk/quality scores (stored as strings) and PR statistics.
    "new_contributors": pa.schema(
        [
            ("repo", pa.string()),
            ("snapshot_id", pa.string()),
            ("report_generated_at", pa.string()),
            ("window_days", pa.int64()),
            ("author_login", pa.string()),
            ("name", pa.string()),
            ("profile_url", pa.string()),
            ("repo_pull_requests_url", pa.string()),
            ("repo_issues_url", pa.string()),
            # Activity in this repo.
            ("repo_first_seen_at", pa.string()),
            ("repo_last_seen_at", pa.string()),
            ("repo_primary_artifact_count", pa.int64()),
            ("repo_artifact_count", pa.int64()),
            # Activity inside this snapshot.
            ("snapshot_issue_count", pa.int64()),
            ("snapshot_pr_count", pa.int64()),
            ("snapshot_comment_count", pa.int64()),
            ("snapshot_review_count", pa.int64()),
            ("snapshot_review_comment_count", pa.int64()),
            ("repo_association", pa.string()),
            ("new_to_repo", pa.bool_()),
            ("first_seen_in_snapshot", pa.bool_()),
            ("report_reason", pa.string()),
            # Heuristic account signals.
            ("account_age_days", pa.int64()),
            ("young_account", pa.bool_()),
            ("follow_through_score", pa.string()),
            ("breadth_score", pa.string()),
            ("automation_risk_signal", pa.string()),
            ("heuristic_note", pa.string()),
            ("public_orgs", pa.list_(pa.string())),
            # Cross-repo authored-PR statistics.
            ("visible_authored_pr_count", pa.int64()),
            ("merged_pr_count", pa.int64()),
            ("closed_unmerged_pr_count", pa.int64()),
            ("open_pr_count", pa.int64()),
            ("merged_pr_rate", pa.float64()),
            ("closed_unmerged_pr_rate", pa.float64()),
            ("still_open_pr_rate", pa.float64()),
            ("distinct_repos_with_authored_prs", pa.int64()),
            ("distinct_repos_with_open_prs", pa.int64()),
            ("fetch_error", pa.string()),
        ]
    ),
}
def _tmp_path(path: Path) -> Path:
path.parent.mkdir(parents=True, exist_ok=True)
fd, raw_path = tempfile.mkstemp(prefix=f".{path.name}.", suffix=".tmp", dir=path.parent)
os.close(fd)
return Path(raw_path)
def write_parquet(rows: list[dict[str, Any]], path: Path, table_name: str) -> None:
    """Atomically write *rows* to *path* as a Parquet file.

    Rows are converted with the registered Arrow schema for *table_name*
    (a key of ``SCHEMAS``); an unknown name raises ``KeyError``, and rows
    that don't match the schema raise from pyarrow.  Data is written to a
    sibling temp file first and renamed over *path*, so readers never
    observe a partially written file.
    """
    schema = SCHEMAS[table_name]
    table = pa.Table.from_pylist(rows, schema=schema)
    tmp_path = _tmp_path(path)
    try:
        pq.write_table(table, tmp_path)
        tmp_path.replace(path)
    finally:
        # On success replace() already consumed tmp_path; missing_ok
        # avoids the racy exists()-then-unlink check on the failure path.
        tmp_path.unlink(missing_ok=True)
def write_json(data: Any, path: Path) -> None:
    """Atomically write *data* to *path* as pretty-printed JSON.

    Output is UTF-8, indented, key-sorted (deterministic for diffing)
    and ends with a trailing newline.  A sibling temp file plus rename
    keeps readers from ever seeing a partial file.
    """
    tmp_path = _tmp_path(path)
    try:
        tmp_path.write_text(json.dumps(data, indent=2, sort_keys=True) + "\n", encoding="utf-8")
        tmp_path.replace(path)
    finally:
        # On success replace() already consumed tmp_path; missing_ok
        # avoids the racy exists()-then-unlink check on the failure path.
        tmp_path.unlink(missing_ok=True)
def read_json(path: Path) -> Any:
    """Parse and return the JSON document stored at *path* (UTF-8)."""
    raw = path.read_text(encoding="utf-8")
    return json.loads(raw)
def write_text(content: str, path: Path) -> None:
    """Atomically write *content* to *path* as UTF-8 text.

    Uses the same write-to-temp-then-rename pattern as the other writers
    in this module so readers never see a partial file.
    """
    tmp_path = _tmp_path(path)
    try:
        tmp_path.write_text(content, encoding="utf-8")
        tmp_path.replace(path)
    finally:
        # On success replace() already consumed tmp_path; missing_ok
        # avoids the racy exists()-then-unlink check on the failure path.
        tmp_path.unlink(missing_ok=True)
def read_parquet_rows(path: Path) -> list[dict[str, Any]]:
    """Return the rows of the Parquet file at *path* as dicts.

    A missing file is treated as an empty table and yields ``[]``.
    """
    if not path.exists():
        return []
    table = pq.read_table(path)
    return table.to_pylist()