Buckets:
linvest21/shft-artifacts / code /self_healing_finetuning /data_pipeline /training_data_validation.py
| from __future__ import annotations | |
| import json | |
| import re | |
| import shutil | |
| from pathlib import Path | |
| from typing import Any | |
| from n21.config import write_json | |
| from observability.audit_log import utc_now | |
# Regex heuristics used to detect contradictory statements about a concept.
# Each concept maps a polarity ("positive" / "negative") to a list of patterns;
# patterns are written in lowercase because callers lowercase the text before
# searching. Insertion order is significant: it drives report ordering.
CONCEPT_PATTERNS: dict[str, dict[str, list[str]]] = {
    "leverage": {
        "positive": [
            r"\bleverage (can|may) (improve|enhance|support)",
            r"\bdebt (can|may) be (optimal|efficient)",
        ],
        "negative": [
            r"\bleverage is always bad\b",
            r"\bdebt is always bad\b",
            r"\bavoid all debt\b",
        ],
    },
    "valuation_multiple": {
        "positive": [
            r"\blow p/e (can|may) indicate value\b",
            r"\blow valuation (can|may) be attractive\b",
        ],
        "negative": [
            r"\blow p/e always means buy\b",
            r"\blow valuation always means buy\b",
        ],
    },
    "risk_management": {
        "positive": [
            r"\brisk should be sized\b",
            r"\bmonitor downside\b",
            r"\bscenario analysis\b",
        ],
        "negative": [
            r"\brisk can be ignored\b",
            r"\bignore downside\b",
        ],
    },
}
def load_jsonl(path: Path) -> tuple[list[dict[str, Any]], list[str]]:
    """Read a JSON Lines file, collecting object records and per-line errors.

    The file is decoded as UTF-8 with an optional BOM. Whitespace-only lines
    are skipped silently. A line that is not valid JSON, or whose top-level
    value is not an object, is reported in the error list and excluded from
    the records.

    Args:
        path: Location of the JSONL file to read.

    Returns:
        A ``(records, errors)`` pair. ``records`` holds the parsed objects in
        file order; ``errors`` holds messages of the form
        ``"<path>:<line number>: <problem>"`` with 1-based line numbers.
    """
    records: list[dict[str, Any]] = []
    problems: list[str] = []
    with path.open("r", encoding="utf-8-sig", newline="") as stream:
        for line_no, raw in enumerate(stream, start=1):
            if raw.strip():
                try:
                    parsed = json.loads(raw)
                except json.JSONDecodeError as exc:
                    problems.append(f"{path}:{line_no}: invalid JSONL: {exc.msg}")
                else:
                    if isinstance(parsed, dict):
                        records.append(parsed)
                    else:
                        problems.append(f"{path}:{line_no}: record is not an object")
    return records, problems
def text_from_record(record: dict[str, Any]) -> str:
    """Flatten a chat record into one searchable text blob.

    The content of every dict-shaped message is emitted on its own line,
    followed by a final line holding the record's ``metadata`` serialized as
    JSON with sorted keys (``{}`` when the key is absent).

    This function is also run on records that failed schema validation, so it
    tolerates malformed input: a ``messages`` value that is not a list or
    tuple (for example ``null`` or a number) is treated as having no
    messages instead of raising ``TypeError``, and a ``null`` message content
    contributes an empty line rather than the literal text ``"None"``.

    Args:
        record: A parsed training record.

    Returns:
        Message contents and serialized metadata joined with newlines.
    """
    messages = record.get("messages")
    if not isinstance(messages, (list, tuple)):
        # Malformed or missing messages: nothing to extract, but don't crash.
        messages = ()
    parts: list[str] = []
    for message in messages:
        if not isinstance(message, dict):
            continue
        content = message.get("content", "")
        parts.append("" if content is None else str(content))
    parts.append(json.dumps(record.get("metadata", {}), sort_keys=True))
    return "\n".join(parts)
def validate_record(record: dict[str, Any], label: str) -> list[str]:
    """Check that a record has a usable chat ``messages`` list.

    A record is valid when ``messages`` is a list of at least two entries,
    both a ``user`` and an ``assistant`` role appear among the dict-shaped
    entries, every entry is a dict, and every entry has content that is not
    empty after stripping whitespace. A missing or ``null`` content counts as
    empty (a ``null`` must not slip through as the string ``"None"``).

    Args:
        record: A parsed training record.
        label: Prefix identifying the record in error messages.

    Returns:
        A list of error messages, empty when the record is valid. When
        ``messages`` is not a list or has fewer than two entries, that single
        error is returned and no further checks run.
    """
    messages = record.get("messages")
    if not isinstance(messages, list) or len(messages) < 2:
        return [f"{label}: messages must be a non-empty chat list"]

    errors: list[str] = []
    roles = [message.get("role") for message in messages if isinstance(message, dict)]
    if "user" not in roles or "assistant" not in roles:
        errors.append(f"{label}: messages must include user and assistant roles")

    for index, message in enumerate(messages, start=1):
        if not isinstance(message, dict):
            errors.append(f"{label}: message {index} is not an object")
            continue
        content = message.get("content")
        # Check for None before str(): str(None) is "None", which is non-empty.
        if content is None or not str(content).strip():
            errors.append(f"{label}: message {index} has empty content")
    return errors
def concept_votes(rows: list[dict[str, Any]]) -> dict[str, dict[str, list[int]]]:
    """Record which rows match each concept's positive or negative patterns.

    Each row is flattened with ``text_from_record`` and lowercased, then
    searched with every pattern in ``CONCEPT_PATTERNS``. A row is counted at
    most once per concept and polarity, however many patterns match.

    Args:
        rows: Parsed training records.

    Returns:
        A mapping ``concept -> {"positive": [...], "negative": [...]}`` whose
        lists hold the 0-based positions of matching rows in ascending order.
        Every concept is present even when nothing matched.
    """
    tally: dict[str, dict[str, list[int]]] = {}
    for name in CONCEPT_PATTERNS:
        tally[name] = {"positive": [], "negative": []}

    for position, record in enumerate(rows):
        lowered = text_from_record(record).lower()
        for name, by_polarity in CONCEPT_PATTERNS.items():
            for polarity, regexes in by_polarity.items():
                for regex in regexes:
                    if re.search(regex, lowered):
                        tally[name][polarity].append(position)
                        # One hit is enough; avoid double-counting the row.
                        break
    return tally
def conflict_report(rows: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[int]]:
    """Find concepts asserted with both polarities and flag the minority side.

    A concept is in conflict when at least one row matches its positive
    patterns and at least one matches its negative patterns. The polarity
    with fewer matching rows is the minority; on a tie the negative side is
    treated as the minority.

    Args:
        rows: Parsed training records.

    Returns:
        A ``(conflicts, quarantine_indexes)`` pair. ``conflicts`` has one
        entry per conflicting concept with the counts and row indexes of each
        side; ``quarantine_indexes`` is the sorted, de-duplicated union of
        all minority row indexes.
    """
    findings: list[dict[str, Any]] = []
    flagged: set[int] = set()

    for concept, tallies in concept_votes(rows).items():
        pos = tallies["positive"]
        neg = tallies["negative"]
        if pos and neg:
            if len(neg) <= len(pos):
                minority, majority = neg, pos
            else:
                minority, majority = pos, neg
            flagged.update(minority)
            findings.append(
                {
                    "concept": concept,
                    "positive_count": len(pos),
                    "negative_count": len(neg),
                    "majority_record_indexes": majority,
                    "minority_record_indexes": minority,
                    "recommendation": "quarantine minority concept polarity for trainer review",
                }
            )
    return findings, sorted(flagged)
def validate_training_data(
    *,
    source: Path,
    output_dir: Path,
    backup_dir: Path | None = None,
    apply_quarantine: bool = False,
) -> dict[str, Any]:
    """Validate training JSONL data and write validation/conflict reports.

    Loads ``source`` (a single file, or every ``*.hf_finetune.jsonl`` found
    recursively under a directory, in sorted order), schema-checks each
    record, detects concept-polarity conflicts across all records, and
    writes three JSON files into ``output_dir`` through the project's
    ``write_json`` helper: ``training_data_validation_report.json``,
    ``conflict_report.json`` and ``quarantine_manifest.json``.

    When ``apply_quarantine`` is true and any records are recommended for
    quarantine, copies of those records are written to
    ``quarantined_records.jsonl`` in ``backup_dir`` (default:
    ``output_dir / "quarantine_backup"``). The source files themselves are
    not modified by this function.

    Args:
        source: JSONL file or directory to validate.
        output_dir: Directory for the reports; created if missing.
        backup_dir: Optional directory for the quarantined-records copy.
        apply_quarantine: Whether to write the quarantined-records copy.

    Returns:
        The validation report dict. Its ``ok`` flag is true only when no
        schema errors were found; at most the first 200 schema errors are
        included, while ``schema_error_count`` gives the full count.

    Raises:
        FileNotFoundError: If ``source`` does not exist.
    """
    if not source.exists():
        raise FileNotFoundError(f"training data source does not exist: {source}")
    output_dir.mkdir(parents=True, exist_ok=True)

    if source.is_file():
        inputs = [source]
    else:
        inputs = sorted(source.rglob("*.hf_finetune.jsonl"))

    combined: list[dict[str, Any]] = []
    problems: list[str] = []
    per_file: list[dict[str, Any]] = []
    for file_path in inputs:
        loaded, parse_errors = load_jsonl(file_path)
        problems.extend(parse_errors)
        # Position of this file's first record in the combined list, so the
        # global indexes in the conflict report can be traced back to a file.
        start = len(combined)
        for ordinal, record in enumerate(loaded, start=1):
            problems.extend(validate_record(record, f"{file_path}:{ordinal}"))
        combined.extend(loaded)
        per_file.append(
            {"path": str(file_path), "record_count": len(loaded), "global_offset": start}
        )

    conflicts, flagged = conflict_report(combined)
    held_back = [combined[position] for position in flagged]

    manifest_entries = []
    for position in flagged:
        flagged_record = combined[position]
        manifest_entries.append(
            {
                "index": position,
                "metadata": flagged_record.get("metadata", {}),
                "preview": text_from_record(flagged_record)[:500],
            }
        )
    quarantine_manifest = {
        "recommended_quarantine_count": len(held_back),
        "recommended_record_indexes": flagged,
        "apply_quarantine": apply_quarantine,
        "backup_dir": str(backup_dir) if backup_dir else None,
        "records": manifest_entries,
    }

    if apply_quarantine and held_back:
        target = backup_dir or (output_dir / "quarantine_backup")
        target.mkdir(parents=True, exist_ok=True)
        serialized = [json.dumps(row, ensure_ascii=True, sort_keys=True) for row in held_back]
        (target / "quarantined_records.jsonl").write_text(
            "\n".join(serialized) + "\n",
            encoding="utf-8",
        )
        # Record the directory actually used, including the default fallback.
        quarantine_manifest["backup_dir"] = str(target)

    report = {
        "schema_version": "training_data_validation_v1",
        "source": str(source),
        "record_count": len(combined),
        "source_reports": per_file,
        "schema_error_count": len(problems),
        "schema_errors": problems[:200],
        "conflict_count": len(conflicts),
        "conflicts": conflicts,
        "quarantine_manifest": quarantine_manifest,
        "ok": len(problems) == 0,
        "created_at": utc_now(),
    }
    write_json(output_dir / "training_data_validation_report.json", report)
    write_json(output_dir / "conflict_report.json", {"conflicts": conflicts, "created_at": utc_now()})
    write_json(output_dir / "quarantine_manifest.json", quarantine_manifest)
    return report
Xet Storage Details
- Size:
- 7.19 kB
- Xet hash:
- 43d4fbf2eb3595ec62e233517730183a7ee4aea208f9c4a8fee84210f181a597
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.